diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b94ec680e9e..41cbd8bcaac 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -172,7 +172,7 @@ repos:
       - id: mypy
         name: "mypy: syft"
         always_run: true
-        files: "^packages/syft/src/syft/client"
+        files: "^packages/syft/src/syft/client|^packages/syft/src/syft/node"
         args: [
           "--follow-imports=skip",
           "--ignore-missing-imports",
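Note: the `files` value in the mypy hook is a Python regular expression matched against each staged path, so the one-line change above (adding the `|` alternation) is what brings the whole `syft/node` tree under the same type-checking hook that previously covered only `syft/client`. A quick illustrative sketch of what the new pattern selects (the paths below are examples, not part of the change):

    import re

    # Same pattern as the hook's `files:` key.
    PATTERN = re.compile(r"^packages/syft/src/syft/client|^packages/syft/src/syft/node")

    # Hypothetical staged paths, for illustration only.
    paths = [
        "packages/syft/src/syft/client/api.py",       # covered before and after
        "packages/syft/src/syft/node/node.py",        # newly covered
        "packages/syft/src/syft/service/service.py",  # still outside the hook
    ]
    for path in paths:
        print(path, bool(PATTERN.match(path)))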
diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py
index a9b161bc63f..e869c308708 100644
--- a/packages/syft/src/syft/node/node.py
+++ b/packages/syft/src/syft/node/node.py
@@ -18,7 +18,6 @@
 from typing import Dict
 from typing import List
 from typing import Optional
-from typing import Tuple
 from typing import Type
 from typing import Union
 import uuid
@@ -39,6 +38,7 @@
 from ..client.api import SyftAPICall
 from ..client.api import SyftAPIData
 from ..client.api import debox_signed_syftapicall_response
+from ..client.client import SyftClient
 from ..exceptions.exception import PySyftException
 from ..external import OBLV
 from ..protocol.data_protocol import PROTOCOL_TYPE
@@ -63,6 +63,7 @@
 from ..service.enclave.enclave_service import EnclaveService
 from ..service.job.job_service import JobService
 from ..service.job.job_stash import Job
+from ..service.job.job_stash import JobStash
 from ..service.log.log_service import LogService
 from ..service.metadata.metadata_service import MetadataService
 from ..service.metadata.node_metadata import NodeMetadataV3
@@ -101,9 +102,11 @@
 from ..service.worker.utils import DEFAULT_WORKER_POOL_NAME
 from ..service.worker.utils import create_default_image
 from ..service.worker.worker_image_service import SyftWorkerImageService
+from ..service.worker.worker_pool import WorkerPool
 from ..service.worker.worker_pool_service import SyftWorkerPoolService
 from ..service.worker.worker_pool_stash import SyftWorkerPoolStash
 from ..service.worker.worker_service import WorkerService
+from ..service.worker.worker_stash import WorkerStash
 from ..store.blob_storage import BlobStorageConfig
 from ..store.blob_storage.on_disk import OnDiskBlobStorageClientConfig
 from ..store.blob_storage.on_disk import OnDiskBlobStorageConfig
@@ -196,7 +199,7 @@ def get_default_worker_pool_name() -> str:
     return get_env("DEFAULT_WORKER_POOL_NAME", DEFAULT_WORKER_POOL_NAME)
 
 
-def get_default_worker_pool_count(node) -> int:
+def get_default_worker_pool_count(node: Node) -> int:
     return int(
         get_env(
             "DEFAULT_WORKER_POOL_COUNT", node.queue_config.client_config.n_consumers
@@ -204,7 +207,7 @@ def get_default_worker_pool_count(node) -> int:
     )
 
 
-def in_kubernetes() -> Optional[str]:
+def in_kubernetes() -> bool:
     return get_container_host() == "k8s"
 
 
@@ -242,7 +245,7 @@ def get_syft_worker_uid() -> Optional[str]:
 
 
 class AuthNodeContextRegistry:
-    __node_context_registry__: Dict[Tuple, Node] = OrderedDict()
+    __node_context_registry__: Dict[str, Node] = OrderedDict()
 
     @classmethod
     def set_node_context(
@@ -250,7 +253,7 @@ def set_node_context(
         node_uid: Union[UID, str],
         context: NodeServiceContext,
         user_verify_key: Union[SyftVerifyKey, str],
-    ):
+    ) -> None:
         if isinstance(node_uid, str):
             node_uid = UID.from_string(node_uid)
 
@@ -290,9 +293,9 @@ def __init__(
         signing_key: Optional[Union[SyftSigningKey, SigningKey]] = None,
         action_store_config: Optional[StoreConfig] = None,
         document_store_config: Optional[StoreConfig] = None,
-        root_email: str = default_root_email,
-        root_username: str = default_root_username,
-        root_password: str = default_root_password,
+        root_email: Optional[str] = default_root_email,
+        root_username: Optional[str] = default_root_username,
+        root_password: Optional[str] = default_root_password,
         processes: int = 0,
         is_subprocess: bool = False,
         node_type: Union[str, NodeType] = NodeType.DOMAIN,
@@ -394,7 +397,7 @@ def __init__(
             node=self,
         )
 
-        self.client_cache = {}
+        self.client_cache: dict = {}
         if isinstance(node_type, str):
             node_type = NodeType(node_type)
         self.node_type = node_type
@@ -425,7 +428,7 @@ def __init__(
         NodeRegistry.set_node_for(self.id, self)
 
     @property
-    def runs_in_docker(self):
+    def runs_in_docker(self) -> bool:
         path = "/proc/self/cgroup"
         return (
             os.path.exists("/.dockerenv")
@@ -457,14 +460,14 @@ def init_blob_storage(self, config: Optional[BlobStorageConfig] = None) -> None:
                     remote_profile.profile_name
                 ] = remote_profile
 
-    def stop(self):
+    def stop(self) -> None:
         for consumer_list in self.queue_manager.consumers.values():
             for c in consumer_list:
                 c.close()
         for p in self.queue_manager.producers.values():
             p.close()
 
-    def close(self):
+    def close(self) -> None:
         self.stop()
 
     def create_queue_config(
@@ -493,10 +496,10 @@
 
         return queue_config_
 
-    def init_queue_manager(self, queue_config: QueueConfig):
+    def init_queue_manager(self, queue_config: QueueConfig) -> None:
         MessageHandlers = [APICallMessageHandler]
         if self.is_subprocess:
-            return
+            return None
 
         self.queue_manager = QueueManager(config=queue_config)
         for message_handler in MessageHandlers:
@@ -552,7 +555,7 @@ def add_consumer_for_service(
         syft_worker_id: UID,
         address: str,
         message_handler: AbstractMessageHandler = APICallMessageHandler,
-    ):
+    ) -> None:
         consumer: QueueConsumer = self.queue_manager.create_consumer(
             message_handler,
             address=address,
@@ -664,7 +667,7 @@ def is_root(self, credentials: SyftVerifyKey) -> bool:
         return credentials == self.verify_key
 
     @property
-    def root_client(self):
+    def root_client(self) -> SyftClient:
         # relative
         from ..client.client import PythonConnection
 
@@ -673,7 +676,8 @@ def root_client(self):
         if isinstance(client_type, SyftError):
             return client_type
         root_client = client_type(connection=connection, credentials=self.signing_key)
-        root_client.api.refresh_api_callback()
+        if root_client.api.refresh_api_callback is not None:
+            root_client.api.refresh_api_callback()
         return root_client
 
     def _find_klasses_pending_for_migration(
@@ -707,7 +711,7 @@
 
         return klasses_to_be_migrated
 
-    def find_and_migrate_data(self):
+    def find_and_migrate_data(self) -> None:
         # Track all object type that need migration for document store
         context = AuthedServiceContext(
             node=self,
@@ -772,7 +776,7 @@
         print("Data Migrated to latest version !!!")
 
     @property
-    def guest_client(self):
+    def guest_client(self) -> SyftClient:
         return self.get_guest_client()
 
     @property
@@ -780,7 +784,7 @@ def current_protocol(self) -> List:
         data_protocol = get_data_protocol()
         return data_protocol.latest_version
 
-    def get_guest_client(self, verbose: bool = True):
+    def get_guest_client(self, verbose: bool = True) -> SyftClient:
         # relative
         from ..client.client import PythonConnection
 
@@ -798,7 +802,8 @@ def get_guest_client(self, verbose: bool = True):
         guest_client = client_type(
             connection=connection, credentials=SyftSigningKey.generate()
         )
-        guest_client.api.refresh_api_callback()
+        if guest_client.api.refresh_api_callback is not None:
+            guest_client.api.refresh_api_callback()
         return guest_client
 
     def __repr__(self) -> str:
@@ -840,7 +845,7 @@ def init_stores(
         self,
         document_store_config: Optional[StoreConfig] = None,
         action_store_config: Optional[StoreConfig] = None,
-    ):
+    ) -> None:
         if document_store_config is None:
             if self.local_db or (self.processes > 0 and not self.is_subprocess):
                 client_config = SQLiteStoreClientConfig(path=self.sqlite_path)
@@ -905,14 +910,14 @@ def init_stores(
         self.queue_stash = QueueStash(store=self.document_store)
 
     @property
-    def job_stash(self):
+    def job_stash(self) -> JobStash:
         return self.get_service("jobservice").stash
 
     @property
-    def worker_stash(self):
+    def worker_stash(self) -> WorkerStash:
         return self.get_service("workerservice").stash
 
-    def _construct_services(self):
+    def _construct_services(self) -> None:
         self.service_path_map = {}
 
         for service_klass in self.services:
@@ -1135,7 +1140,7 @@ def handle_api_call(
         self,
         api_call: Union[SyftAPICall, SignedSyftAPICall],
         job_id: Optional[UID] = None,
-        check_call_location=True,
+        check_call_location: bool = True,
     ) -> Result[SignedSyftAPICall, Err]:
         # Get the result
         result = self.handle_api_call_with_unsigned_result(
@@ -1150,7 +1155,7 @@ def handle_api_call_with_unsigned_result(
         self,
         api_call: Union[SyftAPICall, SignedSyftAPICall],
         job_id: Optional[UID] = None,
-        check_call_location=True,
+        check_call_location: bool = True,
    ) -> Result[Union[QueueItem, SyftObject], Err]:
         if self.required_signed_calls and isinstance(api_call, SyftAPICall):
             return SyftError(
@@ -1212,12 +1217,12 @@
 
     def add_action_to_queue(
         self,
-        action,
-        credentials,
-        parent_job_id=None,
+        action: Action,
+        credentials: SyftVerifyKey,
+        parent_job_id: Optional[UID] = None,
         has_execute_permissions: bool = False,
         worker_pool_name: Optional[str] = None,
-    ):
+    ) -> Union[Job, SyftError]:
         job_id = UID()
         task_uid = UID()
         worker_settings = WorkerSettings.from_node(node=self)
@@ -1267,8 +1272,12 @@
         )
 
     def add_queueitem_to_queue(
-        self, queue_item, credentials, action=None, parent_job_id=None
-    ):
+        self,
+        queue_item: QueueItem,
+        credentials: SyftVerifyKey,
+        action: Optional[Action] = None,
+        parent_job_id: Optional[UID] = None,
+    ) -> Union[Job, SyftError]:
         log_id = UID()
         role = self.get_role_for_credentials(credentials=credentials)
         context = AuthedServiceContext(node=self, credentials=credentials, role=role)
@@ -1329,7 +1338,9 @@ def _is_usercode_call_on_owned_kwargs(
         user_code_service = self.get_service("usercodeservice")
         return user_code_service.is_execution_on_owned_args(api_call.kwargs, context)
 
-    def add_api_call_to_queue(self, api_call, parent_job_id=None):
+    def add_api_call_to_queue(
+        self, api_call: SyftAPICall, parent_job_id: Optional[UID] = None
+    ) -> Union[Job, SyftError]:
         unsigned_call = api_call
         if isinstance(api_call, SignedSyftAPICall):
             unsigned_call = api_call.message
@@ -1416,7 +1427,7 @@ def pool_stash(self) -> SyftWorkerPoolStash:
     def user_code_stash(self) -> UserCodeStash:
         return self.get_service(UserCodeService).stash
 
-    def get_default_worker_pool(self):
+    def get_default_worker_pool(self) -> WorkerPool:
         result = self.pool_stash.get_by_name(
             credentials=self.verify_key,
             pool_name=get_default_worker_pool_name(),
@@ -1481,6 +1492,7 @@ def create_initial_settings(self, admin_email: str) -> Optional[NodeSettingsV2]:
             return None
         except Exception as e:
             print("create_worker_metadata failed", e)
+            return None
 
 
 def create_admin_new(
@@ -1521,6 +1533,8 @@
     except Exception as e:
         print("Unable to create new admin", e)
 
+    return None
+
 
 def create_oblv_key_pair(
     worker: Node,
@@ -1544,6 +1558,9 @@
             print(f"Using Existing Public/Private Key pair: {len(oblv_keys_stash)}")
     except Exception as e:
         print("Unable to create Oblv Keys.", e)
+        return None
+
+    return None
 
 
 class NodeRegistry:
@@ -1569,7 +1586,7 @@ def get_all_nodes(cls) -> List[Node]:
         return list(cls.__node_registry__.values())
 
 
-def get_default_worker_tag_by_env(dev_mode=False):
+def get_default_worker_tag_by_env(dev_mode: bool = False) -> str:
     if in_kubernetes():
         return get_default_worker_image()
     elif dev_mode:
@@ -1617,7 +1634,7 @@ def create_default_worker_pool(node: Node) -> Optional[SyftError]:
 
     if isinstance(result, SyftError):
         print("Failed to build default worker image: ", result.message)
-        return
+        return None
 
     # Create worker pool if it doesn't exists
     print(
@@ -1650,7 +1667,7 @@ def create_default_worker_pool(node: Node) -> Optional[SyftError]:
 
     if isinstance(result, SyftError):
         print(f"Default worker pool error. {result.message}")
-        return
+        return None
 
     for n in range(worker_to_add_):
         container_status = result[n]
         if container_status.error:
             print(
                 f"Failed to create container: Worker: {container_status.worker},"
                 f"Error: {container_status.error}"
             )
-            return
+            return None
 
     print("Created default worker pool.")
+    return None
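Note: two patterns repeat throughout the node.py hunks above. First, functions annotated `Optional[...]` now make their implicit None returns explicit with `return None`, so every exit path visibly matches the annotation. Second, `api.refresh_api_callback` is treated as an optional callable and is None-checked before being invoked, which is the narrowing mypy expects. A minimal, self-contained sketch of both patterns (the names here are illustrative, not syft APIs):

    from typing import Callable, Dict, Optional

    def find_pool(pools: Dict[str, str], name: str) -> Optional[str]:
        if name in pools:
            return pools[name]
        # Explicit instead of implicit fall-through: the None return
        # is now stated rather than implied.
        return None

    refresh: Optional[Callable[[], None]] = None

    # Narrow Optional[Callable] before calling it; an unguarded
    # refresh() is a type error because None is not callable.
    if refresh is not None:
        refresh()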
diff --git a/packages/syft/src/syft/node/run.py b/packages/syft/src/syft/node/run.py
index 3b8376fb4e1..10aa942a498 100644
--- a/packages/syft/src/syft/node/run.py
+++ b/packages/syft/src/syft/node/run.py
@@ -2,6 +2,9 @@
 import argparse
 from typing import Optional
 
+# third party
+from hagrid.orchestra import NodeHandle
+
 # relative
 from ..client.deploy import Orchestra
 
@@ -14,7 +17,7 @@ def str_to_bool(bool_str: Optional[str]) -> bool:
     return result
 
 
-def run():
+def run() -> Optional[NodeHandle]:
     parser = argparse.ArgumentParser()
     parser.add_argument("command", help="command: launch", type=str, default="none")
     parser.add_argument(
print(f"Using Existing Public/Private Key pair: {len(oblv_keys_stash)}") except Exception as e: print("Unable to create Oblv Keys.", e) + return None + + return None class NodeRegistry: @@ -1569,7 +1586,7 @@ def get_all_nodes(cls) -> List[Node]: return list(cls.__node_registry__.values()) -def get_default_worker_tag_by_env(dev_mode=False): +def get_default_worker_tag_by_env(dev_mode: bool = False) -> str: if in_kubernetes(): return get_default_worker_image() elif dev_mode: @@ -1617,7 +1634,7 @@ def create_default_worker_pool(node: Node) -> Optional[SyftError]: if isinstance(result, SyftError): print("Failed to build default worker image: ", result.message) - return + return None # Create worker pool if it doesn't exists print( @@ -1650,7 +1667,7 @@ def create_default_worker_pool(node: Node) -> Optional[SyftError]: if isinstance(result, SyftError): print(f"Default worker pool error. {result.message}") - return + return None for n in range(worker_to_add_): container_status = result[n] @@ -1659,6 +1676,7 @@ def create_default_worker_pool(node: Node) -> Optional[SyftError]: f"Failed to create container: Worker: {container_status.worker}," f"Error: {container_status.error}" ) - return + return None print("Created default worker pool.") + return None diff --git a/packages/syft/src/syft/node/run.py b/packages/syft/src/syft/node/run.py index 3b8376fb4e1..10aa942a498 100644 --- a/packages/syft/src/syft/node/run.py +++ b/packages/syft/src/syft/node/run.py @@ -2,6 +2,9 @@ import argparse from typing import Optional +# third party +from hagrid.orchestra import NodeHandle + # relative from ..client.deploy import Orchestra @@ -14,7 +17,7 @@ def str_to_bool(bool_str: Optional[str]) -> bool: return result -def run(): +def run() -> Optional[NodeHandle]: parser = argparse.ArgumentParser() parser.add_argument("command", help="command: launch", type=str, default="none") parser.add_argument( diff --git a/packages/syft/src/syft/node/server.py b/packages/syft/src/syft/node/server.py index fc7bbb2bc8d..eabcacc49c1 100644 --- a/packages/syft/src/syft/node/server.py +++ b/packages/syft/src/syft/node/server.py @@ -81,7 +81,7 @@ def run_uvicorn( queue_port: Optional[int], create_producer: bool, n_consumers: int, -): +) -> None: async def _run_uvicorn( name: str, node_type: Enum, @@ -90,7 +90,7 @@ async def _run_uvicorn( reset: bool, dev_mode: bool, node_side_type: Enum, - ): + ) -> None: if node_type not in worker_classes: raise NotImplementedError(f"node_type: {node_type} is not supported") worker_class = worker_classes[node_type] @@ -205,7 +205,7 @@ def serve_node( ), ) - def stop(): + def stop() -> None: print(f"Stopping {name}") server_process.terminate() server_process.join(3) @@ -214,7 +214,7 @@ def stop(): server_process.kill() print("killed") - def start(): + def start() -> None: print(f"Starting {name} server on {host}:{port}") server_process.start() diff --git a/packages/syft/src/syft/node/worker_settings.py b/packages/syft/src/syft/node/worker_settings.py index 106e2e94821..509488ac5db 100644 --- a/packages/syft/src/syft/node/worker_settings.py +++ b/packages/syft/src/syft/node/worker_settings.py @@ -2,6 +2,7 @@ from __future__ import annotations # stdlib +from typing import Callable from typing import Optional # third party @@ -55,9 +56,9 @@ class WorkerSettings(SyftObject): blob_store_config: Optional[BlobStorageConfig] queue_config: Optional[QueueConfig] - @staticmethod - def from_node(node: AbstractNode) -> Self: - return WorkerSettings( + @classmethod + def from_node(cls, node: AbstractNode) -> Self: 
diff --git a/packages/syft/src/syft/node/worker_settings.py b/packages/syft/src/syft/node/worker_settings.py
index 106e2e94821..509488ac5db 100644
--- a/packages/syft/src/syft/node/worker_settings.py
+++ b/packages/syft/src/syft/node/worker_settings.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 # stdlib
+from typing import Callable
 from typing import Optional
 
 # third party
@@ -55,9 +56,9 @@ class WorkerSettings(SyftObject):
     blob_store_config: Optional[BlobStorageConfig]
     queue_config: Optional[QueueConfig]
 
-    @staticmethod
-    def from_node(node: AbstractNode) -> Self:
-        return WorkerSettings(
+    @classmethod
+    def from_node(cls, node: AbstractNode) -> Self:
+        return cls(
             id=node.id,
             name=node.name,
             node_type=node.node_type,
@@ -74,14 +75,14 @@ def from_node(node: AbstractNode) -> Self:
 
 
 @migrate(WorkerSettings, WorkerSettingsV1)
-def downgrade_workersettings_v2_to_v1():
+def downgrade_workersettings_v2_to_v1() -> list[Callable]:
     return [
         drop(["queue_config"]),
     ]
 
 
 @migrate(WorkerSettingsV1, WorkerSettings)
-def upgrade_workersettings_v1_to_v2():
+def upgrade_workersettings_v1_to_v2() -> list[Callable]:
     # relative
     from ..service.queue.zmq_queue import ZMQQueueConfig
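Note: one behavioral point in the worker_settings.py change is worth spelling out. Converting `from_node` from a `@staticmethod` that hard-codes `WorkerSettings(...)` into a `@classmethod` that returns `cls(...)` means a subclass calling the factory gets an instance of itself, which is also what the `Self` return annotation promises. A stripped-down illustration with toy classes (not the real `WorkerSettings`):

    class Settings:
        def __init__(self, name: str) -> None:
            self.name = name

        @classmethod
        def from_name(cls, name: str) -> "Settings":
            # `cls` is whichever class the method was called on,
            # so subclasses construct themselves here.
            return cls(name)

    class EnclaveSettings(Settings):
        pass

    print(type(Settings.from_name("a")).__name__)         # Settings
    print(type(EnclaveSettings.from_name("b")).__name__)  # EnclaveSettings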