Skip to content

Commit

Permalink
[refactor] done fixing mypy issues for syft/node
Browse files Browse the repository at this point in the history
  • Loading branch information
khoaguin committed Feb 22, 2024
1 parent 09df3cf commit 81d06d4
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 50 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ repos:
- id: mypy
name: "mypy: syft"
always_run: true
files: "^packages/syft/src/syft/client"
files: "^packages/syft/src/syft/client|^packages/syft/src/syft/node"
args: [
"--follow-imports=skip",
"--ignore-missing-imports",
Expand Down
96 changes: 57 additions & 39 deletions packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type
from typing import Union
import uuid
Expand All @@ -39,6 +38,7 @@
from ..client.api import SyftAPICall
from ..client.api import SyftAPIData
from ..client.api import debox_signed_syftapicall_response
from ..client.client import SyftClient
from ..exceptions.exception import PySyftException
from ..external import OBLV
from ..protocol.data_protocol import PROTOCOL_TYPE
Expand All @@ -63,6 +63,7 @@
from ..service.enclave.enclave_service import EnclaveService
from ..service.job.job_service import JobService
from ..service.job.job_stash import Job
from ..service.job.job_stash import JobStash
from ..service.log.log_service import LogService
from ..service.metadata.metadata_service import MetadataService
from ..service.metadata.node_metadata import NodeMetadataV3
Expand Down Expand Up @@ -101,9 +102,11 @@
from ..service.worker.utils import DEFAULT_WORKER_POOL_NAME
from ..service.worker.utils import create_default_image
from ..service.worker.worker_image_service import SyftWorkerImageService
from ..service.worker.worker_pool import WorkerPool
from ..service.worker.worker_pool_service import SyftWorkerPoolService
from ..service.worker.worker_pool_stash import SyftWorkerPoolStash
from ..service.worker.worker_service import WorkerService
from ..service.worker.worker_stash import WorkerStash
from ..store.blob_storage import BlobStorageConfig
from ..store.blob_storage.on_disk import OnDiskBlobStorageClientConfig
from ..store.blob_storage.on_disk import OnDiskBlobStorageConfig
Expand Down Expand Up @@ -196,15 +199,15 @@ def get_default_worker_pool_name() -> str:
return get_env("DEFAULT_WORKER_POOL_NAME", DEFAULT_WORKER_POOL_NAME)


def get_default_worker_pool_count(node: Node) -> int:
    """Return how many workers the default pool should run.

    Reads the ``DEFAULT_WORKER_POOL_COUNT`` environment variable and falls
    back to the node's configured number of queue consumers when unset.
    Always coerces the value to ``int`` since env vars arrive as strings.
    """
    return int(
        get_env(
            "DEFAULT_WORKER_POOL_COUNT", node.queue_config.client_config.n_consumers
        )
    )


def in_kubernetes() -> bool:
    """Return True when this process is running inside Kubernetes.

    Delegates to ``get_container_host`` and checks for the "k8s" marker.
    """
    return get_container_host() == "k8s"


Expand Down Expand Up @@ -242,15 +245,15 @@ def get_syft_worker_uid() -> Optional[str]:


class AuthNodeContextRegistry:
__node_context_registry__: Dict[Tuple, Node] = OrderedDict()
__node_context_registry__: Dict[str, Node] = OrderedDict()

@classmethod
def set_node_context(
cls,
node_uid: Union[UID, str],
context: NodeServiceContext,
user_verify_key: Union[SyftVerifyKey, str],
):
) -> None:
if isinstance(node_uid, str):
node_uid = UID.from_string(node_uid)

Expand Down Expand Up @@ -290,9 +293,9 @@ def __init__(
signing_key: Optional[Union[SyftSigningKey, SigningKey]] = None,
action_store_config: Optional[StoreConfig] = None,
document_store_config: Optional[StoreConfig] = None,
root_email: str = default_root_email,
root_username: str = default_root_username,
root_password: str = default_root_password,
root_email: Optional[str] = default_root_email,
root_username: Optional[str] = default_root_username,
root_password: Optional[str] = default_root_password,
processes: int = 0,
is_subprocess: bool = False,
node_type: Union[str, NodeType] = NodeType.DOMAIN,
Expand Down Expand Up @@ -394,7 +397,7 @@ def __init__(
node=self,
)

self.client_cache = {}
self.client_cache: dict = {}
if isinstance(node_type, str):
node_type = NodeType(node_type)
self.node_type = node_type
Expand Down Expand Up @@ -425,7 +428,7 @@ def __init__(
NodeRegistry.set_node_for(self.id, self)

@property
def runs_in_docker(self):
def runs_in_docker(self) -> bool:
path = "/proc/self/cgroup"
return (
os.path.exists("/.dockerenv")
Expand Down Expand Up @@ -457,14 +460,14 @@ def init_blob_storage(self, config: Optional[BlobStorageConfig] = None) -> None:
remote_profile.profile_name
] = remote_profile

def stop(self) -> None:
    """Shut down the node's queueing infrastructure.

    Closes every registered queue consumer first, then every producer,
    so no consumer is left reading from a closed producer.
    """
    for consumer_list in self.queue_manager.consumers.values():
        for consumer in consumer_list:
            consumer.close()
    for producer in self.queue_manager.producers.values():
        producer.close()

def close(self) -> None:
    """Alias for :meth:`stop`; shuts down the node's queue machinery."""
    self.stop()

def create_queue_config(
Expand Down Expand Up @@ -493,10 +496,10 @@ def create_queue_config(

return queue_config_

def init_queue_manager(self, queue_config: QueueConfig):
def init_queue_manager(self, queue_config: QueueConfig) -> None:
MessageHandlers = [APICallMessageHandler]
if self.is_subprocess:
return
return None

self.queue_manager = QueueManager(config=queue_config)
for message_handler in MessageHandlers:
Expand Down Expand Up @@ -552,7 +555,7 @@ def add_consumer_for_service(
syft_worker_id: UID,
address: str,
message_handler: AbstractMessageHandler = APICallMessageHandler,
):
) -> None:
consumer: QueueConsumer = self.queue_manager.create_consumer(
message_handler,
address=address,
Expand Down Expand Up @@ -664,7 +667,7 @@ def is_root(self, credentials: SyftVerifyKey) -> bool:
return credentials == self.verify_key

@property
def root_client(self):
def root_client(self) -> SyftClient:
# relative
from ..client.client import PythonConnection

Expand All @@ -673,7 +676,8 @@ def root_client(self):
if isinstance(client_type, SyftError):
return client_type
root_client = client_type(connection=connection, credentials=self.signing_key)
root_client.api.refresh_api_callback()
if root_client.api.refresh_api_callback is not None:
root_client.api.refresh_api_callback()
return root_client

def _find_klasses_pending_for_migration(
Expand Down Expand Up @@ -707,7 +711,7 @@ def _find_klasses_pending_for_migration(

return klasses_to_be_migrated

def find_and_migrate_data(self):
def find_and_migrate_data(self) -> None:
# Track all object type that need migration for document store
context = AuthedServiceContext(
node=self,
Expand Down Expand Up @@ -772,15 +776,15 @@ def find_and_migrate_data(self):
print("Data Migrated to latest version !!!")

@property
def guest_client(self) -> SyftClient:
    """A guest (freshly keyed, unauthenticated) client for this node."""
    return self.get_guest_client()

@property
def current_protocol(self) -> List:
    """Latest data-protocol version this node speaks."""
    return get_data_protocol().latest_version

def get_guest_client(self, verbose: bool = True):
def get_guest_client(self, verbose: bool = True) -> SyftClient:
# relative
from ..client.client import PythonConnection

Expand All @@ -798,7 +802,8 @@ def get_guest_client(self, verbose: bool = True):
guest_client = client_type(
connection=connection, credentials=SyftSigningKey.generate()
)
guest_client.api.refresh_api_callback()
if guest_client.api.refresh_api_callback is not None:
guest_client.api.refresh_api_callback()
return guest_client

def __repr__(self) -> str:
Expand Down Expand Up @@ -840,7 +845,7 @@ def init_stores(
self,
document_store_config: Optional[StoreConfig] = None,
action_store_config: Optional[StoreConfig] = None,
):
) -> None:
if document_store_config is None:
if self.local_db or (self.processes > 0 and not self.is_subprocess):
client_config = SQLiteStoreClientConfig(path=self.sqlite_path)
Expand Down Expand Up @@ -905,14 +910,14 @@ def init_stores(
self.queue_stash = QueueStash(store=self.document_store)

@property
def job_stash(self) -> JobStash:
    """Stash backing the registered job service ("jobservice")."""
    return self.get_service("jobservice").stash

@property
def worker_stash(self) -> WorkerStash:
    """Stash backing the registered worker service ("workerservice")."""
    return self.get_service("workerservice").stash

def _construct_services(self):
def _construct_services(self) -> None:
self.service_path_map = {}

for service_klass in self.services:
Expand Down Expand Up @@ -1135,7 +1140,7 @@ def handle_api_call(
self,
api_call: Union[SyftAPICall, SignedSyftAPICall],
job_id: Optional[UID] = None,
check_call_location=True,
check_call_location: bool = True,
) -> Result[SignedSyftAPICall, Err]:
# Get the result
result = self.handle_api_call_with_unsigned_result(
Expand All @@ -1150,7 +1155,7 @@ def handle_api_call_with_unsigned_result(
self,
api_call: Union[SyftAPICall, SignedSyftAPICall],
job_id: Optional[UID] = None,
check_call_location=True,
check_call_location: bool = True,
) -> Result[Union[QueueItem, SyftObject], Err]:
if self.required_signed_calls and isinstance(api_call, SyftAPICall):
return SyftError(
Expand Down Expand Up @@ -1212,12 +1217,12 @@ def handle_api_call_with_unsigned_result(

def add_action_to_queue(
self,
action,
credentials,
parent_job_id=None,
action: Action,
credentials: SyftVerifyKey,
parent_job_id: Optional[UID] = None,
has_execute_permissions: bool = False,
worker_pool_name: Optional[str] = None,
):
) -> Union[Job, SyftError]:
job_id = UID()
task_uid = UID()
worker_settings = WorkerSettings.from_node(node=self)
Expand Down Expand Up @@ -1267,8 +1272,12 @@ def add_action_to_queue(
)

def add_queueitem_to_queue(
self, queue_item, credentials, action=None, parent_job_id=None
):
self,
queue_item: QueueItem,
credentials: SyftVerifyKey,
action: Optional[Action] = None,
parent_job_id: Optional[UID] = None,
) -> Union[Job, SyftError]:
log_id = UID()
role = self.get_role_for_credentials(credentials=credentials)
context = AuthedServiceContext(node=self, credentials=credentials, role=role)
Expand Down Expand Up @@ -1329,7 +1338,9 @@ def _is_usercode_call_on_owned_kwargs(
user_code_service = self.get_service("usercodeservice")
return user_code_service.is_execution_on_owned_args(api_call.kwargs, context)

def add_api_call_to_queue(self, api_call, parent_job_id=None):
def add_api_call_to_queue(
self, api_call: SyftAPICall, parent_job_id: Optional[UID] = None
) -> Union[Job, SyftError]:
unsigned_call = api_call
if isinstance(api_call, SignedSyftAPICall):
unsigned_call = api_call.message
Expand Down Expand Up @@ -1416,7 +1427,7 @@ def pool_stash(self) -> SyftWorkerPoolStash:
def user_code_stash(self) -> UserCodeStash:
return self.get_service(UserCodeService).stash

def get_default_worker_pool(self):
def get_default_worker_pool(self) -> WorkerPool:
result = self.pool_stash.get_by_name(
credentials=self.verify_key,
pool_name=get_default_worker_pool_name(),
Expand Down Expand Up @@ -1481,6 +1492,7 @@ def create_initial_settings(self, admin_email: str) -> Optional[NodeSettingsV2]:
return None
except Exception as e:
print("create_worker_metadata failed", e)
return None


def create_admin_new(
Expand Down Expand Up @@ -1521,6 +1533,8 @@ def create_admin_new(
except Exception as e:
print("Unable to create new admin", e)

return None


def create_oblv_key_pair(
worker: Node,
Expand All @@ -1544,6 +1558,9 @@ def create_oblv_key_pair(
print(f"Using Existing Public/Private Key pair: {len(oblv_keys_stash)}")
except Exception as e:
print("Unable to create Oblv Keys.", e)
return None

return None


class NodeRegistry:
Expand All @@ -1569,7 +1586,7 @@ def get_all_nodes(cls) -> List[Node]:
return list(cls.__node_registry__.values())


def get_default_worker_tag_by_env(dev_mode=False):
def get_default_worker_tag_by_env(dev_mode: bool = False) -> str:
if in_kubernetes():
return get_default_worker_image()
elif dev_mode:
Expand Down Expand Up @@ -1617,7 +1634,7 @@ def create_default_worker_pool(node: Node) -> Optional[SyftError]:

if isinstance(result, SyftError):
print("Failed to build default worker image: ", result.message)
return
return None

# Create worker pool if it doesn't exists
print(
Expand Down Expand Up @@ -1650,7 +1667,7 @@ def create_default_worker_pool(node: Node) -> Optional[SyftError]:

if isinstance(result, SyftError):
print(f"Default worker pool error. {result.message}")
return
return None

for n in range(worker_to_add_):
container_status = result[n]
Expand All @@ -1659,6 +1676,7 @@ def create_default_worker_pool(node: Node) -> Optional[SyftError]:
f"Failed to create container: Worker: {container_status.worker},"
f"Error: {container_status.error}"
)
return
return None

print("Created default worker pool.")
return None
5 changes: 4 additions & 1 deletion packages/syft/src/syft/node/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import argparse
from typing import Optional

# third party
from hagrid.orchestra import NodeHandle

# relative
from ..client.deploy import Orchestra

Expand All @@ -14,7 +17,7 @@ def str_to_bool(bool_str: Optional[str]) -> bool:
return result


def run():
def run() -> Optional[NodeHandle]:
parser = argparse.ArgumentParser()
parser.add_argument("command", help="command: launch", type=str, default="none")
parser.add_argument(
Expand Down
Loading

0 comments on commit 81d06d4

Please sign in to comment.