From d188d42a873ca2afe2a4e0694dfa0a6520c9a3ce Mon Sep 17 00:00:00 2001 From: khoaguin Date: Wed, 7 Feb 2024 21:47:10 +0700 Subject: [PATCH 01/42] [refactor] done fixing mypy issues for `syft/service/request` --- .pre-commit-config.yaml | 2 +- packages/syft/src/syft/serde/third_party.py | 4 +- .../syft/src/syft/service/request/request.py | 115 +++++++++++------- .../syft/service/request/request_service.py | 54 ++++---- 4 files changed, 103 insertions(+), 72 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 81b29307e8a..40638e839fc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -172,7 +172,7 @@ repos: - id: mypy name: "mypy: syft" always_run: true - files: "^packages/syft/src/syft/serde|^packages/syft/src/syft/util/env.py|^packages/syft/src/syft/util/logger.py|^packages/syft/src/syft/util/markdown.py|^packages/syft/src/syft/util/notebook_ui/notebook_addons.py" + files: "^packages/syft/src/syft/serde|^packages/syft/src/syft/util/env.py|^packages/syft/src/syft/util/logger.py|^packages/syft/src/syft/util/markdown.py|^packages/syft/src/syft/util/notebook_ui/notebook_addons.py|^packages/syft/src/syft/service/request" #files: "^packages/syft/src/syft/serde" args: [ "--follow-imports=skip", diff --git a/packages/syft/src/syft/serde/third_party.py b/packages/syft/src/syft/serde/third_party.py index 1abfe2d9cdc..2d70250cbe6 100644 --- a/packages/syft/src/syft/serde/third_party.py +++ b/packages/syft/src/syft/serde/third_party.py @@ -96,8 +96,8 @@ def deserialize_dataframe(buf: bytes) -> DataFrame: def deserialize_series(blob: bytes) -> Series: - df = DataFrame.from_dict(deserialize(blob, from_bytes=True)) - return df[df.columns[0]] + df: DataFrame = DataFrame.from_dict(deserialize(blob, from_bytes=True)) + return Series(df[df.columns[0]]) recursive_serde_register( diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 699e577e2b9..e8d3bd1a679 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -19,6 +19,7 @@ # relative from ...abstract_node import NodeSideType from ...client.api import APIRegistry +from ...client.client import SyftClient from ...custom_worker.config import WorkerConfig from ...custom_worker.k8s import IN_KUBERNETES from ...node.credentials import SyftVerifyKey @@ -77,7 +78,7 @@ class Change(SyftObject): linked_obj: Optional[LinkedObject] def is_type(self, type_: type) -> bool: - return self.linked_obj and type_ == self.linked_obj.object_type + return (self.linked_obj is not None) and (type_ == self.linked_obj.object_type) @serializable() @@ -182,7 +183,7 @@ def apply(self, context: ChangeContext) -> Result[SyftSuccess, SyftError]: def undo(self, context: ChangeContext) -> Result[SyftSuccess, SyftError]: return self._run(context=context, apply=False) - def __repr_syft_nested__(self): + def __repr_syft_nested__(self) -> str: return f"Apply {self.apply_permission_type} to \ {self.linked_obj.object_type.__canonical_name__}:{self.linked_obj.object_uid.short()}" @@ -259,7 +260,7 @@ def apply(self, context: ChangeContext) -> Result[SyftSuccess, SyftError]: def undo(self, context: ChangeContext) -> Result[SyftSuccess, SyftError]: return self._run(context=context, apply=False) - def __repr_syft_nested__(self): + def __repr_syft_nested__(self) -> str: return f"Create Image for Config: {self.config} with tag: {self.tag}" @@ -321,7 +322,7 @@ def apply(self, context: ChangeContext) -> Result[SyftSuccess, SyftError]: def undo(self, context: ChangeContext) -> Result[SyftSuccess, SyftError]: return self._run(context=context, apply=False) - def __repr_syft_nested__(self): + def __repr_syft_nested__(self) -> str: return ( f"Create Worker Pool '{self.pool_name}' for Image with id {self.image_uid}" ) @@ -364,7 +365,7 @@ def _repr_html_(self) -> Any: updated_at_line += ( f"

Created by: {self.requesting_user_name}

" ) - str_changes = [] + str_changes_ = [] for change in self.changes: str_change = ( change.__repr_syft_nested__() @@ -372,8 +373,8 @@ def _repr_html_(self) -> Any: else type(change) ) str_change = f"{str_change}. " - str_changes.append(str_change) - str_changes = "\n".join(str_changes) + str_changes_.append(str_change) + str_changes = "\n".join(str_changes_) api = APIRegistry.api_for( self.node_uid, self.syft_client_verify_key, @@ -420,7 +421,7 @@ def _repr_html_(self) -> Any: """ - def _coll_repr_(self): + def _coll_repr_(self) -> Dict[str, Union[str, Dict[str, str]]]: if self.status == RequestStatus.APPROVED: badge_color = "badge-green" elif self.status == RequestStatus.PENDING: @@ -473,7 +474,7 @@ def current_change_state(self) -> Dict[UID, bool]: return change_applied_map @property - def icon(self): + def icon(self) -> str: return REQUEST_ICON @property @@ -496,7 +497,7 @@ def approve( disable_warnings: bool = False, approve_nested: bool = False, **kwargs: dict, - ): + ) -> Result[SyftSuccess, SyftError]: api = APIRegistry.api_for( self.node_uid, self.syft_client_verify_key, @@ -529,7 +530,7 @@ def approve( print(f"Approving request for domain {api.node_name}") return api.services.request.apply(self.id, **kwargs) - def deny(self, reason: str): + def deny(self, reason: str) -> Union[SyftSuccess, SyftError]: """Denies the particular request. Args: @@ -541,7 +542,7 @@ def deny(self, reason: str): ) return api.services.request.undo(uid=self.id, reason=reason) - def approve_with_client(self, client): + def approve_with_client(self, client: SyftClient) -> Result[SyftSuccess, SyftError]: print(f"Approving request for domain {client.name}") return client.api.services.request.apply(self.id) @@ -624,7 +625,9 @@ def _get_latest_or_create_job(self) -> Union[Job, SyftError]: return job - def accept_by_depositing_result(self, result: Any, force: bool = False): + def accept_by_depositing_result( + self, result: Any, force: bool = False + ) -> Union[SyftError, SyftSuccess]: # this code is extremely brittle because its a work around that relies on # the type of request being very specifically tied to code which needs approving @@ -647,10 +650,15 @@ def accept_by_depositing_result(self, result: Any, force: bool = False): change = self.changes[0] if not change.is_type(UserCode): - raise TypeError( - f"accept_by_depositing_result can only be run on {UserCode} not " - f"{change.linked_obj.object_type}" - ) + if change.linked_obj is not None: + raise TypeError( + f"accept_by_depositing_result can only be run on {UserCode} not " + f"{change.linked_obj.object_type}" + ) + else: + raise TypeError( + f"accept_by_depositing_result can only be run on {UserCode}" + ) if not type(change) == UserCodeStatusChange: raise TypeError( f"accept_by_depositing_result can only be run on {UserCodeStatusChange} not " @@ -743,7 +751,9 @@ def accept_by_depositing_result(self, result: Any, force: bool = False): return SyftSuccess(message="Request submitted for updating result.") - def sync_job(self, job_info: JobInfo, **kwargs) -> Result[SyftSuccess, SyftError]: + def sync_job( + self, job_info: JobInfo, **kwargs: Any + ) -> Result[SyftSuccess, SyftError]: if job_info.includes_result: return SyftError( message="This JobInfo includes a Result. Please use Request.accept_by_depositing_result instead." @@ -859,25 +869,29 @@ class ObjectMutation(Change): __repr_attrs__ = ["linked_obj", "attr_name"] - def mutate(self, obj: Any, value: Optional[Any]) -> Any: + def mutate(self, obj: Any, value: Optional[Any] = None) -> Any: # check if attribute is a property setter first # this seems necessary for pydantic types attr = getattr(type(obj), self.attr_name, None) if inspect.isdatadescriptor(attr): + assert hasattr(attr, "fget") and hasattr( + attr, "fset" + ), "attr must have fget and fset" self.previous_value = attr.fget(obj) attr.fset(obj, value) - else: self.previous_value = getattr(obj, self.attr_name, None) setattr(obj, self.attr_name, value) return obj - def __repr_syft_nested__(self): + def __repr_syft_nested__(self) -> str: return f"Mutate {self.attr_name} to {self.value}" def _run( self, context: ChangeContext, apply: bool ) -> Result[SyftSuccess, SyftError]: + if self.linked_obj is None: + return Err(SyftError(message=f"{self}'s linked object is None")) try: obj = self.linked_obj.resolve_with_context(context) if obj.is_err(): @@ -937,7 +951,7 @@ def valid(self) -> Union[SyftSuccess, SyftError]: @staticmethod def from_obj( linked_obj: LinkedObject, attr_name: str, value: Optional[Enum] = None - ) -> Self: + ) -> "EnumMutation": enum_type = type_for_field(linked_obj.object_type, attr_name) return EnumMutation( linked_obj=linked_obj, @@ -954,12 +968,14 @@ def _run( valid = self.valid if not valid: return Err(valid) + if self.linked_obj is None: + return Err(SyftError(message=f"{self}'s linked object is None")) obj = self.linked_obj.resolve_with_context(context) if obj.is_err(): - return SyftError(message=obj.err()) + return Err(SyftError(message=obj.err())) obj = obj.ok() if apply: - obj = self.mutate(obj) + obj = self.mutate(obj=obj) self.linked_obj.update_with_context(context, obj) else: @@ -975,7 +991,7 @@ def apply(self, context: ChangeContext) -> Result[SyftSuccess, SyftError]: def undo(self, context: ChangeContext) -> Result[SyftSuccess, SyftError]: return self._run(context=context, apply=False) - def __repr_syft_nested__(self): + def __repr_syft_nested__(self) -> str: return f"Mutate {self.enum_type} to {self.value}" @property @@ -1019,12 +1035,12 @@ class UserCodeStatusChange(Change): ] @property - def code(self): + def code(self) -> Optional[SyftObject]: return self.link @property - def codes(self): - def recursive_code(node): + def codes(self) -> List: + def recursive_code(node: Any) -> List: codes = [] for _, (obj, new_node) in node.items(): codes.append(obj.resolve) @@ -1032,33 +1048,46 @@ def recursive_code(node): return codes codes = [self.link] - codes.extend(recursive_code(self.link.nested_codes)) + if self.link is not None: + codes.extend(recursive_code(self.link.nested_codes)) + return codes - def nested_repr(self, node=None, level=0): + def nested_repr(self, node: Optional[Any] = None, level: int = 0) -> str: msg = "" - if node is None: + if node is None and self.link is not None: node = self.link.nested_codes + if node is None: + return msg for service_func_name, (_, new_node) in node.items(): msg = "├──" + "──" * level + f"{service_func_name}
" msg += self.nested_repr(node=new_node, level=level + 1) return msg - def __repr_syft_nested__(self): - msg = f"Request to change {self.link.service_func_name} (Pool Id: {self.link.worker_pool_name}) " - msg += "to permission RequestStatus.APPROVED" - if self.nested_solved: - if self.link.nested_codes == {}: - msg += ". No nested requests" + def __repr_syft_nested__(self) -> str: + if self.link is not None: + msg = ( + f"Request to change {self.link.service_func_name} " + f"(Pool Id: {self.link.worker_pool_name}) " + ) + msg += "to permission RequestStatus.APPROVED" + if self.nested_solved: + if self.link.nested_codes == {}: + msg += ". No nested requests" + else: + msg += ".

This change requests the following nested functions calls:
" + msg += self.nested_repr() else: - msg += ".

This change requests the following nested functions calls:
" - msg += self.nested_repr() + msg += ". Nested Requests not resolved" else: - msg += ". Nested Requests not resolved" + msg = f"LinkedObject of {self} is None." return msg def _repr_markdown_(self) -> str: link = self.link + if link is None: + return f"{self}'s linked object is None" + input_policy_type = ( link.input_policy_type.__canonical_name__ if link.input_policy_type is not None @@ -1126,7 +1155,7 @@ def mutate(self, obj: UserCode, context: ChangeContext, undo: bool) -> Any: return obj return res - def is_enclave_request(self, user_code: UserCode): + def is_enclave_request(self, user_code: UserCode) -> bool: return ( user_code.is_enclave_code is not None and self.value == UserCodeStatus.APPROVED @@ -1187,14 +1216,14 @@ def link(self) -> Optional[SyftObject]: @migrate(UserCodeStatusChange, UserCodeStatusChangeV1) -def downgrade_usercodestatuschange_v2_to_v1(): +def downgrade_usercodestatuschange_v2_to_v1() -> List[Callable]: return [ drop("nested_solved"), ] @migrate(UserCodeStatusChangeV1, UserCodeStatusChange) -def upgrade_usercodestatuschange_v1_to_v2(): +def upgrade_usercodestatuschange_v1_to_v2() -> List[Callable]: return [ make_set_default("nested_solved", True), ] diff --git a/packages/syft/src/syft/service/request/request_service.py b/packages/syft/src/syft/service/request/request_service.py index b9bfe8761e6..3635ce21960 100644 --- a/packages/syft/src/syft/service/request/request_service.py +++ b/packages/syft/src/syft/service/request/request_service.py @@ -19,6 +19,7 @@ from ..action.action_permissions import ActionPermission from ..code.user_code import UserCode from ..context import AuthedServiceContext +from ..context import NodeServiceContext from ..notification.notification_service import CreateNotification from ..notification.notification_service import NotificationService from ..notification.notifications import Notification @@ -104,7 +105,7 @@ def submit( print("Failed to submit Request", e) raise e - def expand_node(self, context: AuthedServiceContext, code_obj: UserCode): + def expand_node(self, context: AuthedServiceContext, code_obj: UserCode) -> Dict: user_code_service = context.node.get_service("usercodeservice") nested_requests = user_code_service.solve_nested_requests(context, code_obj) @@ -122,7 +123,9 @@ def expand_node(self, context: AuthedServiceContext, code_obj: UserCode): return new_nested_requests - def resolve_nested_requests(self, context, request): + def resolve_nested_requests( + self, context: NodeServiceContext, request: Request + ) -> Request: # TODO: change this if we have more UserCode Changes if len(request.changes) != 1: return request @@ -133,7 +136,7 @@ def resolve_nested_requests(self, context, request): return request code_obj = change.linked_obj.resolve_with_context(context=context).ok() # recursively check what other UserCodes to approve - nested_requests: Dict[str : Tuple[LinkedObject, Dict]] = self.expand_node( + nested_requests: Dict[str, Tuple[LinkedObject, Dict]] = self.expand_node( context, code_obj ) if isinstance(nested_requests, Err): @@ -161,33 +164,31 @@ def get_all_info( context: AuthedServiceContext, page_index: Optional[int] = 0, page_size: Optional[int] = 0, - ) -> Union[List[RequestInfo], SyftError]: - """Get a Dataset""" + ) -> Union[List[List[RequestInfo]], List[RequestInfo], SyftError]: + """Get the information of all requests""" result = self.stash.get_all(context.credentials) + if result.is_err(): + return SyftError(message=result.err()) + method = context.node.get_service_method(UserService.get_by_verify_key) get_message = context.node.get_service_method(NotificationService.filter_by_obj) - requests = [] - if result.is_ok(): - for req in result.ok(): - user = method(req.requesting_user_verify_key).to(UserView) - message = get_message(context=context, obj_uid=req.id) - requests.append( - RequestInfo(user=user, request=req, notification=message) - ) - - # If chunk size is defined, then split list into evenly sized chunks - if page_size: - requests = [ - requests[i : i + page_size] - for i in range(0, len(requests), page_size) - ] - # Return the proper slice using chunk_index - requests = requests[page_index] - + requests: List[RequestInfo] = [] + for req in result.ok(): + user = method(req.requesting_user_verify_key).to(UserView) + message = get_message(context=context, obj_uid=req.id) + requests.append(RequestInfo(user=user, request=req, notification=message)) + if not page_size: return requests - return SyftError(message=result.err()) + # If chunk size is defined, then split list into evenly sized chunks + chunked_requests: List[List[RequestInfo]] = [ + requests[i : i + page_size] for i in range(0, len(requests), page_size) + ] + if page_index: + return chunked_requests[page_index] + else: + return chunked_requests @service_method(path="request.add_changes", name="add_changes") def add_changes( @@ -223,8 +224,9 @@ def filter_all_info( requests = [ requests[i : i + page_size] for i in range(0, len(requests), page_size) ] - # Return the proper slice using chunk_index - requests = requests[page_index] + if page_index is not None: + # Return the proper slice using chunk_index + requests = requests[page_index] return requests From 71dc2c34e9fd9fe4b79128736fce3ed439182661 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Wed, 7 Feb 2024 22:46:30 +0700 Subject: [PATCH 02/42] [refactor] checking if `attr` has both `fget` and `fset` methods, then call the methods instead of using `assert` since assert statements are removed when compiling Python code to optimized bytecode, causing security issue --- packages/syft/src/syft/service/request/request.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index e8d3bd1a679..92fa935fd9a 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -874,11 +874,9 @@ def mutate(self, obj: Any, value: Optional[Any] = None) -> Any: # this seems necessary for pydantic types attr = getattr(type(obj), self.attr_name, None) if inspect.isdatadescriptor(attr): - assert hasattr(attr, "fget") and hasattr( - attr, "fset" - ), "attr must have fget and fset" - self.previous_value = attr.fget(obj) - attr.fset(obj, value) + if hasattr(attr, "fget") and hasattr(attr, "fset"): + self.previous_value = attr.fget(obj) + attr.fset(obj, value) else: self.previous_value = getattr(obj, self.attr_name, None) setattr(obj, self.attr_name, value) From 4a38df526bf1f3bad2f4725125f6a46a85fc45b9 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Thu, 15 Feb 2024 11:42:10 +0700 Subject: [PATCH 03/42] [refactor] done fixing mypy issues for `syft/service/code_history` --- .../syft/service/code_history/code_history.py | 28 ++++++------ .../code_history/code_history_service.py | 45 ++++++++++++------- 2 files changed, 43 insertions(+), 30 deletions(-) diff --git a/packages/syft/src/syft/service/code_history/code_history.py b/packages/syft/src/syft/service/code_history/code_history.py index 0b27d316fbd..ed9b2655a75 100644 --- a/packages/syft/src/syft/service/code_history/code_history.py +++ b/packages/syft/src/syft/service/code_history/code_history.py @@ -4,6 +4,7 @@ from typing import Dict from typing import List from typing import Optional +from typing import Union # relative from ...client.api import APIRegistry @@ -36,7 +37,7 @@ class CodeHistory(SyftObject): __attr_searchable__ = ["user_verify_key", "service_func_name"] - def add_code(self, code: UserCode, comment: Optional[str] = None): + def add_code(self, code: UserCode, comment: Optional[str] = None) -> None: self.user_code_history.append(code.id) if comment is None: comment = "" @@ -54,10 +55,10 @@ class CodeHistoryView(SyftObject): service_func_name: str comment_history: List[str] = [] - def _coll_repr_(self): + def _coll_repr_(self) -> Dict[str, int]: return {"Number of versions": len(self.user_code_history)} - def _repr_html_(self): + def _repr_html_(self) -> str: rows = get_repr_values_table(self.user_code_history, True) for i, r in enumerate(rows): r["Version"] = f"v{i}" @@ -69,14 +70,13 @@ def _repr_html_(self): # rows = sorted(rows, key=lambda x: x["Version"]) return create_table_template(rows, "CodeHistory", table_icon=None) - def __getitem__(self, index: int): + def __getitem__(self, index: int) -> Union[UserCode, SyftError]: api = APIRegistry.api_for(self.syft_node_location, self.syft_client_verify_key) - if api.user_role.value >= ServiceRole.DATA_OWNER.value: - if index < 0: - return SyftError( - message="For security concerns we do not allow negative indexing. \ - Try using absolute values when indexing" - ) + if api.user_role.value >= ServiceRole.DATA_OWNER.value and index < 0: + return SyftError( + message="For security concerns we do not allow negative indexing. \ + Try using absolute values when indexing" + ) return self.user_code_history[index] @@ -89,7 +89,7 @@ class CodeHistoriesDict(SyftObject): id: UID code_versions: Dict[str, CodeHistoryView] = {} - def _repr_html_(self): + def _repr_html_(self) -> str: return f""" {self.code_versions._repr_html_()} """ @@ -120,14 +120,14 @@ class UsersCodeHistoriesDict(SyftObject): __repr_attrs__ = ["available_keys"] @property - def available_keys(self): + def available_keys(self) -> str: return json.dumps(self.user_dict, sort_keys=True, indent=4) - def __getitem__(self, key: int): + def __getitem__(self, key: int) -> Union[CodeHistoriesDict, SyftError]: api = APIRegistry.api_for(self.node_uid, self.syft_client_verify_key) return api.services.code_history.get_history_for_user(key) - def _repr_html_(self): + def _repr_html_(self) -> str: rows = [] for user, funcs in self.user_dict.items(): rows += [{"user": user, "UserCodes": funcs}] diff --git a/packages/syft/src/syft/service/code_history/code_history_service.py b/packages/syft/src/syft/service/code_history/code_history_service.py index af91e61bc76..29212ea390f 100644 --- a/packages/syft/src/syft/service/code_history/code_history_service.py +++ b/packages/syft/src/syft/service/code_history/code_history_service.py @@ -4,6 +4,7 @@ from typing import Union # relative +from ...node.credentials import SyftVerifyKey from ...serde.serializable import serializable from ...store.document_store import DocumentStore from ...types.uid import UID @@ -47,18 +48,17 @@ def submit_version( ) -> Union[SyftSuccess, SyftError]: user_code_service = context.node.get_service("usercodeservice") - final_result = None if isinstance(code, SubmitUserCode): result = user_code_service._submit(context=context, code=code) if result.is_err(): return SyftError(message=str(result.err())) - code: UserCode = result.ok() + code = result.ok() elif isinstance(code, UserCode): result = user_code_service.get_by_uid(context=context, uid=code.id) if isinstance(result, SyftError): return result - code: UserCode = result + code = result result = self.stash.get_by_service_func_name_and_verify_key( credentials=context.credentials, @@ -69,7 +69,7 @@ def submit_version( if result.is_err(): return SyftError(message=result.err()) - code_history: CodeHistory = result.ok() + code_history: Optional[CodeHistory] = result.ok() if code_history is None: code_history = CodeHistory( @@ -87,10 +87,7 @@ def submit_version( if result.is_err(): return SyftError(message=result.err()) - if final_result is None: - return SyftSuccess(message="Code version submit success") - else: - return final_result + return SyftSuccess(message="Code version submit success") @service_method( path="code_history.get_all", name="get_all", roles=DATA_SCIENTIST_ROLE_LEVEL @@ -118,7 +115,9 @@ def get_code_by_uid( return SyftError(message=result.err()) @service_method(path="code_history.delete", name="delete") - def delete(self, context: AuthedServiceContext, uid: UID): + def delete( + self, context: AuthedServiceContext, uid: UID + ) -> Union[SyftSuccess, SyftError]: result = self.stash.delete_by_uid(context.credentials, uid) if result.is_ok(): return result.ok() @@ -126,14 +125,16 @@ def delete(self, context: AuthedServiceContext, uid: UID): return SyftError(message=result.err()) def fetch_histories_for_user( - self, context: AuthedServiceContext, user_verify_key - ) -> CodeHistoriesDict: + self, context: AuthedServiceContext, user_verify_key: SyftVerifyKey + ) -> Union[CodeHistoriesDict, SyftError]: result = self.stash.get_by_verify_key( credentials=context.credentials, user_verify_key=user_verify_key ) + if context.node is None: + return SyftError(message=f"context {context}'s node is None") user_code_service = context.node.get_service("usercodeservice") - def get_code(uid): + def get_code(uid: UID) -> Union[UserCode, SyftError]: return user_code_service.get_by_uid(context=context, uid=uid) if result.is_ok(): @@ -159,7 +160,9 @@ def get_code(uid): name="get_history", roles=DATA_SCIENTIST_ROLE_LEVEL, ) - def get_histories_for_current_user(self, context: AuthedServiceContext): + def get_histories_for_current_user( + self, context: AuthedServiceContext + ) -> Union[CodeHistoriesDict, SyftError]: return self.fetch_histories_for_user( context=context, user_verify_key=context.credentials ) @@ -169,7 +172,9 @@ def get_histories_for_current_user(self, context: AuthedServiceContext): name="get_history_for_user", roles=DATA_OWNER_ROLE_LEVEL, ) - def get_history_for_user(self, context: AuthedServiceContext, email: str): + def get_history_for_user( + self, context: AuthedServiceContext, email: str + ) -> Union[CodeHistoriesDict, SyftError]: user_service = context.node.get_service("userservice") result = user_service.stash.get_by_email( credentials=context.credentials, email=email @@ -186,12 +191,16 @@ def get_history_for_user(self, context: AuthedServiceContext, email: str): name="get_histories", roles=DATA_OWNER_ROLE_LEVEL, ) - def get_histories_group_by_user(self, context: AuthedServiceContext): + def get_histories_group_by_user( + self, context: AuthedServiceContext + ) -> Union[UsersCodeHistoriesDict, SyftError]: result = self.stash.get_all(credentials=context.credentials) if result.is_err(): return SyftError(message=result.err()) code_histories: List[CodeHistory] = result.ok() + if context.node is None: + return SyftError(message=f"context {context}'s node is None") user_service = context.node.get_service("userservice") result = user_service.stash.get_all(context.credentials) if result.is_err(): @@ -223,7 +232,9 @@ def get_by_func_name_and_user_email( service_func_name: str, user_email: str, user_id: UID, - ) -> Union[SyftSuccess, SyftError]: + ) -> Union[List[CodeHistory], SyftError]: + if context.node is None: + return SyftError(message=f"context {context}'s node is None") user_service = context.node.get_service("userservice") user_verify_key = user_service.user_verify_key(user_email) @@ -240,3 +251,5 @@ def get_by_func_name_and_user_email( result = self.stash.find_all(credentials=context.credentials, **kwargs) if result.is_err(): # or len(result) > 1 return result + + return result.ok() From b6bd4899d0971a7a4839fa69b9caf4880951159d Mon Sep 17 00:00:00 2001 From: khoaguin Date: Thu, 15 Feb 2024 15:05:51 +0700 Subject: [PATCH 04/42] [refactor] done fixing mypy issues for `service/project/project_service.py` --- packages/syft/src/syft/client/api.py | 4 +- .../code_history/code_history_service.py | 6 +- .../syft/service/project/project_service.py | 120 ++++++++++-------- 3 files changed, 75 insertions(+), 55 deletions(-) diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index 92538709083..f56e3a41a81 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -87,7 +87,9 @@ def set_api_for( cls.__api_registry__[key] = api @classmethod - def api_for(cls, node_uid: UID, user_verify_key: SyftVerifyKey) -> SyftAPI: + def api_for( + cls, node_uid: UID, user_verify_key: SyftVerifyKey + ) -> Optional[SyftAPI]: key = (node_uid, user_verify_key) return cls.__api_registry__.get(key, None) diff --git a/packages/syft/src/syft/service/code_history/code_history_service.py b/packages/syft/src/syft/service/code_history/code_history_service.py index 29212ea390f..f2405ba2cd8 100644 --- a/packages/syft/src/syft/service/code_history/code_history_service.py +++ b/packages/syft/src/syft/service/code_history/code_history_service.py @@ -46,8 +46,10 @@ def submit_version( code: Union[SubmitUserCode, UserCode], comment: Optional[str] = None, ) -> Union[SyftSuccess, SyftError]: - user_code_service = context.node.get_service("usercodeservice") + if context.node is None: + return SyftError(message=f"context {context}'s node is None") + user_code_service = context.node.get_service("usercodeservice") if isinstance(code, SubmitUserCode): result = user_code_service._submit(context=context, code=code) if result.is_err(): @@ -175,6 +177,8 @@ def get_histories_for_current_user( def get_history_for_user( self, context: AuthedServiceContext, email: str ) -> Union[CodeHistoriesDict, SyftError]: + if context.node is None: + return SyftError(message=f"context {context}'s node is None") user_service = context.node.get_service("userservice") result = user_service.stash.get_by_email( credentials=context.credentials, email=email diff --git a/packages/syft/src/syft/service/project/project_service.py b/packages/syft/src/syft/service/project/project_service.py index 0cef7484b9a..00c4613fd51 100644 --- a/packages/syft/src/syft/service/project/project_service.py +++ b/packages/syft/src/syft/service/project/project_service.py @@ -47,6 +47,8 @@ def __init__(self, store: DocumentStore) -> None: def can_create_project( self, context: AuthedServiceContext ) -> Union[bool, SyftError]: + if context.node is None: + return SyftError(message=f"context {context}'s node is None") user_service = context.node.get_service("userservice") role = user_service.get_role_for_credentials(credentials=context.credentials) if role == ServiceRole.DATA_SCIENTIST: @@ -67,6 +69,9 @@ def create_project( if isinstance(check_role, SyftError): return check_role + if context.node is None: + return SyftError(message=f"context {context}'s node is None") + try: # Check if the project with given id already exists project_id_check = self.stash.get_by_uid( @@ -111,9 +116,14 @@ def create_project( # for the leader node, as it does not have route information to itself # we rely on the data scientist to provide the route # the route is then validated by the leader - leader_node_peer = project.leader_node_route.validate_with_context( - context=context - ) + if project.leader_node_route is not None: + leader_node_peer = project.leader_node_route.validate_with_context( + context=context + ) + else: + return SyftError( + message=f"project {project}'s leader_node_route is None" + ) project_obj.leader_node_peer = leader_node_peer @@ -144,44 +154,41 @@ def add_event( self, context: AuthedServiceContext, project_event: ProjectEvent ) -> Union[SyftSuccess, SyftError]: """To add events to a projects""" + if context.node is None: + return SyftError(message=f"context {context}'s node is None") + # Event object should be received from the leader of the project # retrieve the project object by node verify key project_obj = self.stash.get_by_uid( context.node.verify_key, uid=project_event.project_id ) + if project_obj.is_err(): + return SyftError(message=str(project_obj.err())) - if project_obj.is_ok(): - project: Project = project_obj.ok() - if project.state_sync_leader.verify_key == context.node.verify_key: - return SyftError( - message="Project Events should be passed to leader by broadcast endpoint" - ) - if context.credentials != project.state_sync_leader.verify_key: - return SyftError( - message="Only the leader of the project can add events" - ) - - project.events.append(project_event) - project.event_id_hashmap[project_event.id] = project_event - - message_result = self.check_for_project_request( - project, project_event, context + project: Project = project_obj.ok() + if project.state_sync_leader.verify_key == context.node.verify_key: + return SyftError( + message="Project Events should be passed to leader by broadcast endpoint" ) - if isinstance(message_result, SyftError): - return message_result + if context.credentials != project.state_sync_leader.verify_key: + return SyftError(message="Only the leader of the project can add events") - # updating the project object using root verify key of node - result = self.stash.update(context.node.verify_key, project) + project.events.append(project_event) + project.event_id_hashmap[project_event.id] = project_event - if result.is_err(): - return SyftError(message=str(result.err())) - return SyftSuccess( - message=f"Project event {project_event.id} added successfully " - ) + message_result = self.check_for_project_request(project, project_event, context) + if isinstance(message_result, SyftError): + return message_result - if project_obj.is_err(): - return SyftError(message=str(project_obj.err())) + # updating the project object using root verify key of node + result = self.stash.update(context.node.verify_key, project) + + if result.is_err(): + return SyftError(message=str(result.err())) + return SyftSuccess( + message=f"Project event {project_event.id} added successfully " + ) @service_method( path="project.broadcast_event", @@ -195,6 +202,8 @@ def broadcast_event( # Only the leader of the project could add events to the projects # Any Event to be added to the project should be sent to the leader of the project # The leader broadcasts the event to all the members of the project + if context.node is None: + return SyftError(message=f"context {context}'s node is None") project_obj = self.stash.get_by_uid( context.node.verify_key, uid=project_event.project_id @@ -204,16 +213,16 @@ def broadcast_event( return SyftError(message=str(project_obj.err())) project = project_obj.ok() - if not project.has_permission(context.credentials): return SyftError(message="User does not have permission to add events") - project: Project = project_obj.ok() if project.state_sync_leader.verify_key != context.node.verify_key: return SyftError( message="Only the leader of the project can broadcast events" ) + if project_event.seq_no is None: + return SyftError(message=f"{project_event}.seq_no is None") if project_event.seq_no <= len(project.events) and len(project.events) > 0: return SyftNotReady(message="Project out of sync event") if project_event.seq_no > len(project.events) + 1: @@ -260,33 +269,31 @@ def broadcast_event( ) def sync( self, context: AuthedServiceContext, project_id: UID, seq_no: int - ) -> Union[SyftSuccess, SyftError, List[ProjectEvent]]: + ) -> Union[List[ProjectEvent], SyftError]: """To fetch unsynced events from the project""" + if context.node is None: + return SyftError(message=f"context {context}'s node is None") # Event object should be received from the leader of the project # retrieve the project object by node verify key project_obj = self.stash.get_by_uid(context.node.verify_key, uid=project_id) + if project_obj.is_err(): + return SyftError(message=str(project_obj.err())) - if project_obj.is_ok(): - project: Project = project_obj.ok() - if project.state_sync_leader.verify_key != context.node.verify_key: - return SyftError( - message="Project Events should be synced only with the leader" - ) - - if not project.has_permission(context.credentials): - return SyftError(message="User does not have permission to sync events") + project: Project = project_obj.ok() + if project.state_sync_leader.verify_key != context.node.verify_key: + return SyftError( + message="Project Events should be synced only with the leader" + ) - if seq_no < 0: - return SyftError( - message="Input seq_no should be a non negative integer" - ) + if not project.has_permission(context.credentials): + return SyftError(message="User does not have permission to sync events") - # retrieving unsycned events based on seq_no - return project.events[seq_no:] + if seq_no < 0: + return SyftError(message="Input seq_no should be a non negative integer") - if project_obj.is_err(): - return SyftError(message=str(project_obj.err())) + # retrieving unsycned events based on seq_no + return project.events[seq_no:] @service_method(path="project.get_all", name="get_all", roles=GUEST_ROLE_LEVEL) def get_all(self, context: AuthedServiceContext) -> Union[List[Project], SyftError]: @@ -327,7 +334,11 @@ def get_by_name( name="get_by_uid", roles=GUEST_ROLE_LEVEL, ) - def get_by_uid(self, context: AuthedServiceContext, uid: UID): + def get_by_uid( + self, context: AuthedServiceContext, uid: UID + ) -> Union[Project, SyftError]: + if context.node is None: + return SyftError(message=f"context {context}'s node is None") result = self.stash.get_by_uid( credentials=context.node.verify_key, uid=uid, @@ -343,7 +354,8 @@ def add_signing_key_to_project( ) -> Union[Project, SyftError]: # Automatically infuse signing key of user # requesting get_all() or creating the project object - + if context.node is None: + return SyftError(message=f"context {context}'s node is None") user_service = context.node.get_service("userservice") user = user_service.stash.get_by_verify_key( credentials=context.credentials, verify_key=context.credentials @@ -364,7 +376,7 @@ def check_for_project_request( project: Project, project_event: ProjectEvent, context: AuthedServiceContext, - ): + ) -> Union[SyftSuccess, SyftError]: """To check for project request event and create a message for the root user Args: @@ -375,6 +387,8 @@ def check_for_project_request( Returns: Union[SyftSuccess, SyftError]: SyftSuccess if message is created else SyftError """ + if context.node is None: + return SyftError(message=f"context {context}'s node is None") if ( isinstance(project_event, ProjectRequest) and project_event.linked_request.node_uid == context.node.id From de9fee33732c32bc699de4c691d09ba0cafa6734 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Thu, 15 Feb 2024 21:12:49 +0700 Subject: [PATCH 05/42] [refactor] done fixing mypy issues for `syft/util/` --- .pre-commit-config.yaml | 3 +- packages/syft/src/syft/util/autoreload.py | 6 +++ packages/syft/src/syft/util/schema.py | 16 ++++---- packages/syft/src/syft/util/telemetry.py | 5 +-- .../syft/src/syft/util/trace_decorator.py | 2 +- packages/syft/src/syft/util/util.py | 41 +++++++++++-------- .../syft/src/syft/util/version_compare.py | 2 +- 7 files changed, 44 insertions(+), 31 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6aea4ee653b..229fdc81a72 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -172,8 +172,7 @@ repos: - id: mypy name: "mypy: syft" always_run: true - files: "^packages/syft/src/syft/serde|^packages/syft/src/syft/util/env.py|^packages/syft/src/syft/util/logger.py|^packages/syft/src/syft/util/markdown.py|^packages/syft/src/syft/util/notebook_ui/notebook_addons.py|^packages/syft/src/syft/service" - #files: "^packages/syft/src/syft/serde" + files: "^packages/syft/src/syft/serde|^packages/syft/src/syft/util|^packages/syft/src/syft/service" args: [ "--follow-imports=skip", "--ignore-missing-imports", diff --git a/packages/syft/src/syft/util/autoreload.py b/packages/syft/src/syft/util/autoreload.py index 9c8a06c2740..e1f68e45555 100644 --- a/packages/syft/src/syft/util/autoreload.py +++ b/packages/syft/src/syft/util/autoreload.py @@ -4,6 +4,9 @@ def enable_autoreload() -> None: global AUTORELOAD_ENABLED try: + # third party + from IPython import get_ipython + ipython = get_ipython() # noqa: F821 ipython.run_line_magic("load_ext", "autoreload") ipython.run_line_magic("autoreload", "2") @@ -17,6 +20,9 @@ def enable_autoreload() -> None: def disable_autoreload() -> None: global AUTORELOAD_ENABLED try: + # third party + from IPython import get_ipython + ipython = get_ipython() # noqa: F821 ipython.run_line_magic("autoreload", "0") AUTORELOAD_ENABLED = False diff --git a/packages/syft/src/syft/util/schema.py b/packages/syft/src/syft/util/schema.py index b366f1bbd68..55edd83962a 100644 --- a/packages/syft/src/syft/util/schema.py +++ b/packages/syft/src/syft/util/schema.py @@ -28,8 +28,8 @@ } -def make_fake_type(_type_str: str): - jsonschema = {} +def make_fake_type(_type_str: str) -> dict[str, Any]: + jsonschema: dict = {} jsonschema["title"] = _type_str jsonschema["type"] = "object" jsonschema["properties"] = {} @@ -37,13 +37,13 @@ def make_fake_type(_type_str: str): return jsonschema -def get_type_mapping(_type) -> str: +def get_type_mapping(_type: Any) -> str: if _type in primitive_mapping: return primitive_mapping[_type] return _type.__name__ -def get_types(cls: Type, keys: List[str]) -> Dict[str, Type]: +def get_types(cls: Type, keys: List[str]) -> Optional[Dict[str, Type]]: types = [] for key in keys: _type = None @@ -61,8 +61,10 @@ def get_types(cls: Type, keys: List[str]) -> Dict[str, Type]: return dict(zip(keys, types)) -def convert_attribute_types(cls, attribute_list, attribute_types): - jsonschema = {} +def convert_attribute_types( + cls: Any, attribute_list: Any, attribute_types: Any +) -> dict[str, Any]: + jsonschema: dict = {} jsonschema["title"] = cls.__name__ jsonschema["type"] = "object" jsonschema["properties"] = {} @@ -79,7 +81,7 @@ def process_type_bank(type_bank: Dict[str, Tuple[Any, ...]]) -> Dict[str, Dict]: # first pass gets each type into basic json schema format json_mappings = {} count = 0 - converted_types = defaultdict(int) + converted_types: defaultdict = defaultdict(int) for k in type_bank: count += 1 t = type_bank[k] diff --git a/packages/syft/src/syft/util/telemetry.py b/packages/syft/src/syft/util/telemetry.py index 6d587d44408..15b0690d004 100644 --- a/packages/syft/src/syft/util/telemetry.py +++ b/packages/syft/src/syft/util/telemetry.py @@ -6,7 +6,6 @@ from typing import Union # third party -from typing_extensions import Concatenate from typing_extensions import ParamSpec @@ -25,7 +24,7 @@ def str_to_bool(bool_str: Optional[str]) -> bool: P = ParamSpec("P") -def setup_tracer() -> Callable[Concatenate[T, P], T]: +def setup_tracer() -> Callable[..., T]: def noop(func: T) -> T: return func @@ -85,4 +84,4 @@ def noop(func: T) -> T: return noop -instrument = setup_tracer() +instrument: Callable = setup_tracer() diff --git a/packages/syft/src/syft/util/trace_decorator.py b/packages/syft/src/syft/util/trace_decorator.py index c43749c1a8f..d30c6d7e9fc 100644 --- a/packages/syft/src/syft/util/trace_decorator.py +++ b/packages/syft/src/syft/util/trace_decorator.py @@ -119,7 +119,7 @@ def span_decorator(func_or_class: T) -> T: # We have already decorated this function, override return func_or_class - func_or_class.__tracing_unwrapped__ = func_or_class + func_or_class.__tracing_unwrapped__ = func_or_class # type: ignore tracer = existing_tracer or trace.get_tracer(func_or_class.__module__) diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py index 59e0ed2d333..09f9f9a7e9c 100644 --- a/packages/syft/src/syft/util/util.py +++ b/packages/syft/src/syft/util/util.py @@ -79,14 +79,14 @@ def full_name_with_name(klass: type) -> str: raise e -def get_qualname_for(klass: type): +def get_qualname_for(klass: type) -> str: qualname = getattr(klass, "__qualname__", None) or getattr(klass, "__name__", None) if qualname is None: qualname = extract_name(klass) return qualname -def get_name_for(klass: type): +def get_name_for(klass: type) -> str: klass_name = getattr(klass, "__name__", None) if klass_name is None: klass_name = extract_name(klass) @@ -97,20 +97,23 @@ def get_mb_size(data: Any) -> float: return sys.getsizeof(data) / (1024 * 1024) -def extract_name(klass: type): +def extract_name(klass: type) -> str: name_regex = r".+class.+?([\w\._]+).+" regex2 = r"([\w\.]+)" matches = re.match(name_regex, str(klass)) if matches is None: matches = re.match(regex2, str(klass)) - try: - fqn = matches[1] - if "." in fqn: - return fqn.split(".")[-1] - return fqn - except Exception as e: - print(f"Failed to get klass name {klass}") - raise e + if matches: + try: + fqn = matches[1] + if "." in fqn: + return fqn.split(".")[-1] + return fqn + except Exception as e: + print(f"Failed to get klass name {klass}") + raise e + else: + raise ValueError(f"Failed to match regex for klass {klass}") def validate_type(_object: object, _type: type, optional: bool = False) -> Any: @@ -269,7 +272,7 @@ def print_process( # type: ignore refresh_rate=0.1, ) -> None: with lock: - while not finish.is_set(): # type: ignore + while not finish.is_set(): print(f"{bcolors.bold(message)} .", end="\r") time.sleep(refresh_rate) sys.stdout.flush() @@ -279,7 +282,7 @@ def print_process( # type: ignore print(f"{bcolors.bold(message)} ...", end="\r") time.sleep(refresh_rate) sys.stdout.flush() - if success.is_set(): # type: ignore + if success.is_set(): print(f"{bcolors.success(message)}" + (" " * len(message)), end="\n") else: print(f"{bcolors.failure(message)}" + (" " * len(message)), end="\n") @@ -445,7 +448,7 @@ def obj2pointer_type(obj: Optional[object] = None, fqn: Optional[str] = None) -> critical(log) raise Exception(log) - return ref.pointer_type # type: ignore + return ref.pointer_type def prompt_warning_message(message: str, confirm: bool = False) -> bool: @@ -661,7 +664,7 @@ def inherit_tags( ) -> None: tags = [] if self_obj is not None and hasattr(self_obj, "tags"): - tags.extend(list(self_obj.tags)) # type: ignore + tags.extend(list(self_obj.tags)) for arg in args: if hasattr(arg, "tags"): @@ -692,6 +695,7 @@ def autocache( return download_file(url, file_path) except Exception as e: print(f"Failed to autocache: {url}. {e}") + return None def str_to_bool(bool_str: Optional[str]) -> bool: @@ -856,6 +860,9 @@ def is_interpreter_standard() -> bool: def get_interpreter_module() -> str: try: + # third party + from IPython import get_ipython + shell = get_ipython().__class__.__module__ return shell except NameError: @@ -868,7 +875,7 @@ def get_interpreter_module() -> str: multiprocessing.set_start_method("spawn", True) -def thread_ident() -> int: +def thread_ident() -> Optional[int]: return threading.current_thread().ident @@ -876,7 +883,7 @@ def proc_id() -> int: return os.getpid() -def set_klass_module_to_syft(klass, module_name): +def set_klass_module_to_syft(klass: type, module_name: str) -> None: if module_name not in sys.modules["syft"].__dict__: new_module = types.ModuleType(module_name) else: diff --git a/packages/syft/src/syft/util/version_compare.py b/packages/syft/src/syft/util/version_compare.py index f9b6defdb65..606606bb6e8 100644 --- a/packages/syft/src/syft/util/version_compare.py +++ b/packages/syft/src/syft/util/version_compare.py @@ -33,7 +33,7 @@ def get_operator(version_string: str) -> Tuple[str, Callable, str]: return version_string, op, op_char -def check_rule(version_string: str, LATEST_STABLE_SYFT: str, __version__: str) -> bool: +def check_rule(version_string: str, LATEST_STABLE_SYFT: str, __version__: str) -> tuple: version_string, op, op_char = get_operator(version_string) syft_version = version.parse(__version__) stable_version = version.parse(LATEST_STABLE_SYFT) From 86ca8f1ccf42d992914fc562710da0d95d64f10b Mon Sep 17 00:00:00 2001 From: khoaguin Date: Fri, 16 Feb 2024 14:51:49 +0700 Subject: [PATCH 06/42] [refactor] done fixing mypy issues for `syft/service/action` --- .pre-commit-config.yaml | 3 +- .../syft/service/action/action_data_empty.py | 2 +- .../src/syft/service/action/action_graph.py | 68 +++++----- .../service/action/action_graph_service.py | 13 +- .../src/syft/service/action/action_object.py | 121 ++++++++++-------- .../syft/service/action/action_permissions.py | 5 +- .../src/syft/service/action/action_service.py | 53 +++++--- .../src/syft/service/action/action_store.py | 13 +- .../src/syft/service/action/action_types.py | 2 +- .../syft/src/syft/service/action/numpy.py | 22 ++-- .../syft/src/syft/service/action/pandas.py | 9 +- packages/syft/src/syft/service/action/plan.py | 10 +- 12 files changed, 192 insertions(+), 129 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 229fdc81a72..b4c616c2305 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -172,7 +172,8 @@ repos: - id: mypy name: "mypy: syft" always_run: true - files: "^packages/syft/src/syft/serde|^packages/syft/src/syft/util|^packages/syft/src/syft/service" + # files: "^packages/syft/src/syft/serde|^packages/syft/src/syft/util|^packages/syft/src/syft/service" + files: "^packages/syft/src/syft/service/action" args: [ "--follow-imports=skip", "--ignore-missing-imports", diff --git a/packages/syft/src/syft/service/action/action_data_empty.py b/packages/syft/src/syft/service/action/action_data_empty.py index b9c3960fad3..e32f4e339bb 100644 --- a/packages/syft/src/syft/service/action/action_data_empty.py +++ b/packages/syft/src/syft/service/action/action_data_empty.py @@ -19,7 +19,7 @@ class ActionDataEmpty(SyftObject): __canonical_name__ = "ActionDataEmpty" __version__ = SYFT_OBJECT_VERSION_1 - syft_internal_type: Optional[Type] = NoneType + syft_internal_type: Optional[Type] = NoneType # type: ignore def __repr__(self) -> str: return f"{type(self).__name__} <{self.syft_internal_type}>" diff --git a/packages/syft/src/syft/service/action/action_graph.py b/packages/syft/src/syft/service/action/action_graph.py index c3e25ac098b..f990e1bddfd 100644 --- a/packages/syft/src/syft/service/action/action_graph.py +++ b/packages/syft/src/syft/service/action/action_graph.py @@ -21,7 +21,6 @@ from result import Err from result import Ok from result import Result -from typing_extensions import Self # relative from ...node.credentials import SyftVerifyKey @@ -79,7 +78,7 @@ def make_created_at(cls, v: Optional[DateTime]) -> DateTime: return DateTime.now() if v is None else v @staticmethod - def from_action(action: Action, credentials: SyftVerifyKey): + def from_action(action: Action, credentials: SyftVerifyKey) -> "NodeActionData": is_mutagen = action.remote_self is not None and ( action.remote_self == action.result_id ) @@ -91,24 +90,26 @@ def from_action(action: Action, credentials: SyftVerifyKey): ) @staticmethod - def from_action_obj(action_obj: ActionObject, credentials: SyftVerifyKey): + def from_action_obj( + action_obj: ActionObject, credentials: SyftVerifyKey + ) -> "NodeActionData": return NodeActionData( id=action_obj.id, type=NodeType.ACTION_OBJECT, user_verify_key=credentials, ) - def __hash__(self): + def __hash__(self) -> int: return hash(self.id) - def __eq__(self, other: Self): + def __eq__(self, other: Any) -> bool: if not isinstance(other, NodeActionData): raise NotImplementedError( "Comparisions can be made with NodeActionData type objects only." ) return hash(self) == hash(other) - def __repr__(self): + def __repr__(self) -> str: return self._repr_debug_() @@ -148,7 +149,7 @@ def get(self, uid: Any) -> Any: def delete(self, uid: Any) -> None: raise NotImplementedError - def find_neighbors(self, uid: Any) -> List[Any]: + def find_neighbors(self, uid: Any) -> Optional[List[Any]]: raise NotImplementedError def update(self, uid: Any, data: Any) -> None: @@ -229,7 +230,9 @@ def lock(self) -> SyftLock: def db(self) -> nx.Graph: return self._db - def _thread_safe_cbk(self, cbk: Callable, *args, **kwargs): + def _thread_safe_cbk( + self, cbk: Callable, *args: Any, **kwargs: Any + ) -> Result[Any, str]: # TODO copied method from document_store, have it in one place and reuse? locked = self.lock.acquire(blocking=True) if not locked: @@ -267,10 +270,11 @@ def _delete(self, uid: UID) -> None: self.db.remove_node(uid) self.save() - def find_neighbors(self, uid: UID) -> Optional[Iterable]: + def find_neighbors(self, uid: UID) -> Optional[List[Any]]: if self.exists(uid=uid): neighbors = self.db.neighbors(uid) return neighbors + return None def update(self, uid: UID, data: Any) -> None: self._thread_safe_cbk(self._update, uid=uid, data=data) @@ -294,7 +298,7 @@ def _remove_edge(self, parent: Any, child: Any) -> None: self.db.remove_edge(parent, child) self.save() - def visualize(self, seed: int = 3113794652, figsize=(20, 10)) -> None: + def visualize(self, seed: int = 3113794652, figsize: tuple = (20, 10)) -> None: plt.figure(figsize=figsize) pos = nx.spring_layout(self.db, seed=seed) return nx.draw_networkx(self.db, pos=pos, with_labels=True) @@ -305,10 +309,10 @@ def nodes(self) -> Iterable: def edges(self) -> Iterable: return self.db.edges() - def get_predecessors(self, uid: UID) -> Iterable: + def get_predecessors(self, uid: UID) -> List: return self.db.predecessors(uid) - def get_successors(self, uid: UID) -> Iterable: + def get_successors(self, uid: UID) -> List: return self.db.successors(uid) def is_parent(self, parent: Any, child: Any) -> bool: @@ -372,10 +376,10 @@ def set( credentials: SyftVerifyKey, parent_uids: Optional[List[UID]] = None, ) -> Result[NodeActionData, str]: - if self.graph.exists(uid=node.id): + if self.graph.exists(uid=node.id): # type: ignore[call-arg] return Err(f"Node already exists in the graph: {node}") - self.graph.set(uid=node.id, data=node) + self.graph.set(uid=node.id, data=node) # type: ignore[call-arg] if parent_uids is None: parent_uids = [] @@ -397,8 +401,8 @@ def get( credentials: SyftVerifyKey, ) -> Result[NodeActionData, str]: # 🟡 TODO: Add permission check - if self.graph.exists(uid=uid): - node_data = self.graph.get(uid=uid) + if self.graph.exists(uid=uid): # type: ignore[call-arg] + node_data = self.graph.get(uid=uid) # type: ignore[call-arg] return Ok(node_data) return Err(f"Node does not exists with id: {uid}") @@ -408,8 +412,8 @@ def delete( credentials: SyftVerifyKey, ) -> Result[bool, str]: # 🟡 TODO: Add permission checks - if self.graph.exists(uid=uid): - self.graph.delete(uid=uid) + if self.graph.exists(uid=uid): # type: ignore[call-arg] + self.graph.delete(uid=uid) # type: ignore[call-arg] return Ok(True) return Err(f"Node does not exists with id: {uid}") @@ -420,11 +424,11 @@ def update( credentials: SyftVerifyKey, ) -> Result[NodeActionData, str]: # 🟡 TODO: Add permission checks - node_data = self.graph.get(uid=uid) + node_data = self.graph.get(uid=uid) # type: ignore[call-arg] if node_data is not None: for key, val in data.to_dict(exclude_empty=True).items(): setattr(node_data, key, val) - self.graph.update(uid=uid, data=node_data) + self.graph.update(uid=uid, data=node_data) # type: ignore[call-arg] return Ok(node_data) return Err(f"Node does not exists for uid: {uid}") @@ -438,7 +442,7 @@ def update_non_mutated_successor( Used when a node is a mutagen and to update non-mutated successor for all nodes between node_id and nm_successor_id """ - node_data = self.graph.get(uid=node_id) + node_data = self.graph.get(uid=node_id) # type: ignore[call-arg] data = NodeActionDataUpdate( next_mutagen_node=nm_successor_id, @@ -453,7 +457,7 @@ def update_non_mutated_successor( # loop through successive mutagen nodes and # update their last_nm_mutagen_node id while node_id != nm_successor_id: - node_data = self.graph.get(uid=node_id) + node_data = self.graph.get(uid=node_id) # type: ignore[call-arg] # If node is the last added mutagen node, # then in that case its `next_mutagen_node` will be None @@ -483,7 +487,7 @@ def update_non_mutated_successor( def _get_last_non_mutated_mutagen( self, credentials: SyftVerifyKey, uid: UID ) -> Result[UID, str]: - node_data = self.graph.get(uid=uid) + node_data = self.graph.get(uid=uid) # type: ignore[call-arg] if node_data.is_mutated: return Ok(node_data.last_nm_mutagen_node) @@ -495,10 +499,10 @@ def add_edge( child: UID, credentials: SyftVerifyKey, ) -> Result[bool, str]: - if not self.graph.exists(parent): + if not self.graph.exists(parent): # type: ignore[call-arg] return Err(f"Node does not exists for uid (parent): {parent}") - if not self.graph.exists(child): + if not self.graph.exists(child): # type: ignore[call-arg] return Err(f"Node does not exists for uid (child): {child}") result = self._get_last_non_mutated_mutagen( @@ -511,13 +515,13 @@ def add_edge( new_parent = result.ok() - self.graph.add_edge(parent=new_parent, child=child) + self.graph.add_edge(parent=new_parent, child=child) # type: ignore[call-arg] return Ok(True) def is_parent(self, parent: UID, child: UID) -> Result[bool, str]: - if self.graph.exists(child): - parents = self.graph.get_predecessors(child) + if self.graph.exists(child): # type: ignore[call-arg] + parents = self.graph.get_predecessors(child) # type: ignore[call-arg] result = parent in parents return Ok(result) return Err(f"Node doesn't exists for id: {child}") @@ -529,11 +533,11 @@ def query( ) -> Result[List[NodeActionData], str]: if isinstance(qks, QueryKey): qks = QueryKeys(qks=[qks]) - subgraph = self.graph.subgraph(qks=qks) - return Ok(self.graph.topological_sort(subgraph=subgraph)) + subgraph = self.graph.subgraph(qks=qks) # type: ignore[call-arg] + return Ok(self.graph.topological_sort(subgraph=subgraph)) # type: ignore[call-arg] def nodes(self, credentials: SyftVerifyKey) -> Result[List, str]: - return Ok(self.graph.nodes()) + return Ok(self.graph.nodes()) # type: ignore[call-arg] def edges(self, credentials: SyftVerifyKey) -> Result[List, str]: - return Ok(self.graph.edges()) + return Ok(self.graph.edges()) # type: ignore[call-arg] diff --git a/packages/syft/src/syft/service/action/action_graph_service.py b/packages/syft/src/syft/service/action/action_graph_service.py index 23c2b579401..df09578dcd5 100644 --- a/packages/syft/src/syft/service/action/action_graph_service.py +++ b/packages/syft/src/syft/service/action/action_graph_service.py @@ -1,6 +1,6 @@ # stdlib from typing import List -from typing import Tuple +from typing import Optional from typing import Union # third party @@ -60,6 +60,8 @@ def add_action( if action_node.is_mutagen: # updated non-mutated successor for all nodes between # node_id and nm_successor_id + if action.remote_self is None: + return SyftError(message=f"action {action}'s remote_self is None") result = self.store.update_non_mutated_successor( node_id=action.remote_self.id, nm_successor_id=action_node.id, @@ -102,7 +104,9 @@ def add_action_obj( return result.ok() - def _extract_input_and_output_from_action(self, action: Action) -> Tuple[UID]: + def _extract_input_and_output_from_action( + self, action: Action + ) -> tuple[set[UID], Optional[UID]]: input_uids = set() if action.remote_self is not None: @@ -114,7 +118,10 @@ def _extract_input_and_output_from_action(self, action: Action) -> Tuple[UID]: for _, kwarg in action.kwargs.items(): input_uids.add(kwarg.id) - output_uid = action.result_id.id + if action.result_id is not None: + output_uid = action.result_id.id + else: + output_uid = None return input_uids, output_uid diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index 5b21513f27b..df6e3d0578e 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -77,7 +77,7 @@ class ActionType(Enum): SYFTFUNCTION = 32 -def repr_cls(c): +def repr_cls(c: Any) -> str: return f"{c.__module__}.{c.__name__}" @@ -196,7 +196,9 @@ def syft_history_hash(self) -> int: return hashes @classmethod - def syft_function_action_from_kwargs_and_id(cls, kwargs, user_code_id): + def syft_function_action_from_kwargs_and_id( + cls, kwargs: Any, user_code_id: UID + ) -> Self: kwarg_ids = {} for k, v in kwargs.items(): kwarg_ids[k] = LineageID(v) @@ -230,8 +232,8 @@ def from_api_call(cls, api_call: SyftAPICall) -> Action: ) return action - def __repr__(self): - def repr_uid(_id): + def __repr__(self) -> str: + def repr_uid(_id: LineageID) -> str: return f"{str(_id)[:3]}..{str(_id)[-1]}" arg_repr = ", ".join([repr_uid(x) for x in self.args]) @@ -247,7 +249,7 @@ def repr_uid(_id): @migrate(Action, ActionV1) -def downgrade_action_v2_to_v1(): +def downgrade_action_v2_to_v1() -> list[Callable]: return [ drop("user_code_id"), make_set_default("op", ""), @@ -256,7 +258,7 @@ def downgrade_action_v2_to_v1(): @migrate(ActionV1, Action) -def upgrade_action_v1_to_v2(): +def upgrade_action_v1_to_v2() -> list[Callable]: return [make_set_default("user_code_id", None)] @@ -392,18 +394,19 @@ def make_action_side_effect( context.action = action except Exception as e: raise e - print(f"make_action_side_effect failed with {traceback.format_exc()}") - return Err(f"make_action_side_effect failed with {traceback.format_exc()}") + # print(f"make_action_side_effect failed with {traceback.format_exc()}") + # return Err(f"make_action_side_effect failed with {traceback.format_exc()}") + return Ok((context, args, kwargs)) class TraceResult: - result = [] + result: list = [] _client = None is_tracing = False @classmethod - def reset(cls): + def reset(cls) -> None: cls.result = [] cls._client = None @@ -431,7 +434,7 @@ def convert_to_pointers( if args is not None: for arg in args: if not isinstance(arg, (ActionObject, Asset, UID)): - arg = ActionObject.from_obj( + arg = ActionObject.from_obj( # type: ignore[unreachable] syft_action_data=arg, syft_client_verify_key=api.signing_key.verify_key, syft_node_location=api.node_uid, @@ -446,7 +449,7 @@ def convert_to_pointers( if kwargs is not None: for k, arg in kwargs.items(): if not isinstance(arg, (ActionObject, Asset, UID)): - arg = ActionObject.from_obj( + arg = ActionObject.from_obj( # type: ignore[unreachable] syft_action_data=arg, syft_client_verify_key=api.signing_key.verify_key, syft_node_location=api.node_uid, @@ -646,7 +649,7 @@ def syft_action_data(self) -> Any: return self.syft_action_data_cache - def reload_cache(self): + def reload_cache(self) -> None: # If ActionDataEmpty then try to fetch it from store. if isinstance(self.syft_action_data_cache, ActionDataEmpty): blob_storage_read_method = from_api_or_context( @@ -685,7 +688,7 @@ def reload_cache(self): else: print("cannot reload cache") - def _save_to_blob_storage_(self, data: Any) -> None: + def _save_to_blob_storage_(self, data: Any) -> Optional[SyftError]: # relative from ...types.blob_storage import BlobFile from ...types.blob_storage import CreateBlobStorageEntry @@ -737,6 +740,8 @@ def _save_to_blob_storage_(self, data: Any) -> None: self.syft_action_data_cache = data + return None + def _save_to_blob_storage(self) -> Optional[SyftError]: data = self.syft_action_data if isinstance(data, SyftError): @@ -748,6 +753,7 @@ def _save_to_blob_storage(self) -> Optional[SyftError]: return result if not TraceResult.is_tracing: self.syft_action_data_cache = self.as_empty_data() + return None @property def is_pointer(self) -> bool: @@ -775,25 +781,26 @@ def __check_action_data(cls, values: dict) -> dict: if inspect.isclass(v): values["syft_action_data_repr_"] = repr_cls(v) else: - values["syft_action_data_repr_"] = ( - v._repr_markdown_() - if hasattr(v, "_repr_markdown_") - else v.__repr__() - ) + if v is not None: + values["syft_action_data_repr_"] = ( + v._repr_markdown_() + if hasattr(v, "_repr_markdown_") + else v.__repr__() + ) values["syft_action_data_str_"] = str(v) values["syft_has_bool_attr"] = hasattr(v, "__bool__") return values @property - def is_mock(self): + def is_mock(self) -> bool: return self.syft_twin_type == TwinMode.MOCK @property - def is_real(self): + def is_real(self) -> bool: return self.syft_twin_type == TwinMode.PRIVATE @property - def is_twin(self): + def is_twin(self) -> bool: return self.syft_twin_type != TwinMode.NONE # @pydantic.validator("syft_action_data", pre=True, always=True) @@ -857,7 +864,7 @@ def syft_execute_action( ) return api.make_call(api_call) - def request(self, client): + def request(self, client: SyftClient) -> Union[Any, SyftError]: # relative from ..request.request import ActionStoreChange from ..request.request import SubmitRequest @@ -873,7 +880,7 @@ def request(self, client): ) return client.api.services.request.submit(submit_request) - def _syft_try_to_save_to_store(self, obj) -> None: + def _syft_try_to_save_to_store(self, obj: SyftObject) -> None: if self.syft_node_uid is None or self.syft_client_verify_key is None: return elif obj.syft_node_uid is not None: @@ -894,10 +901,10 @@ def _syft_try_to_save_to_store(self, obj) -> None: api = None if TraceResult._client is not None: - api = TraceResult._client.api + api = TraceResult._client.api # type: ignore[unreachable] if api is not None: - obj._set_obj_location_(api.node_uid, api.signing_key.verify_key) + obj._set_obj_location_(api.node_uid, api.signing_key.verify_key) # type: ignore[unreachable] res = obj._save_to_blob_storage() if isinstance(res, SyftError): print(f"failed saving {obj} to blob storage, error: {res}") @@ -914,7 +921,7 @@ def _syft_try_to_save_to_store(self, obj) -> None: ) if api is not None: - TraceResult.result += [action] + TraceResult.result += [action] # type: ignore[unreachable] else: api = APIRegistry.api_for( node_uid=self.syft_node_location, @@ -924,7 +931,7 @@ def _syft_try_to_save_to_store(self, obj) -> None: if isinstance(res, SyftError): print(f"Failed to to store (arg) {obj} to store, {res}") - def _syft_prepare_obj_uid(self, obj) -> LineageID: + def _syft_prepare_obj_uid(self, obj: Any) -> LineageID: # We got the UID if isinstance(obj, (UID, LineageID)): return LineageID(obj.id) @@ -1045,7 +1052,11 @@ def syft_make_action_with_self( def syft_get_path(self) -> str: """Get the type path of the underlying object""" - if isinstance(self, AnyActionObject) and self.syft_internal_type: + if ( + isinstance(self, AnyActionObject) + and self.syft_internal_type + and self.syft_action_data_type is not None + ): # avoids AnyActionObject errors return f"{self.syft_action_data_type.__name__}" return f"{type(self).__name__}" @@ -1112,7 +1123,7 @@ def get(self, block: bool = False) -> Any: nested_res.syft_client_verify_key = res.syft_client_verify_key return nested_res - def as_empty(self): + def as_empty(self) -> ActionObject: id = self.id # TODO: fix if isinstance(id, LineageID): @@ -1128,7 +1139,7 @@ def from_path( syft_lineage_id: Optional[LineageID] = None, syft_client_verify_key: Optional[SyftVerifyKey] = None, syft_node_location: Optional[UID] = None, - ): + ) -> ActionObject: """Create an Action Object from a file.""" # relative from ...types.blob_storage import BlobFile @@ -1204,20 +1215,20 @@ def from_obj( return action_object @classmethod - def add_trace_hook(cls): + def add_trace_hook(cls) -> bool: return True # if trace_action_side_effect not in self._syft_pre_hooks__[HOOK_ALWAYS]: # self._syft_pre_hooks__[HOOK_ALWAYS].append(trace_action_side_effect) @classmethod - def remove_trace_hook(cls): + def remove_trace_hook(cls) -> bool: return True # self._syft_pre_hooks__[HOOK_ALWAYS].pop(trace_action_side_effct, None) def as_empty_data(self) -> ActionDataEmpty: return ActionDataEmpty(syft_internal_type=self.syft_internal_type) - def wait(self): + def wait(self) -> ActionObject: # relative from ...client.api import APIRegistry @@ -1232,6 +1243,7 @@ def wait(self): while not api.services.action.is_resolved(obj_id): time.sleep(1) + return self @staticmethod @@ -1260,7 +1272,7 @@ def obj_not_ready( @staticmethod def empty( - syft_internal_type: Type[Any] = NoneType, + syft_internal_type: Optional[Type[Any]] = NoneType, # type: ignore[assignment] id: Optional[UID] = None, syft_lineage_id: Optional[LineageID] = None, syft_resolved: Optional[bool] = True, @@ -1399,7 +1411,7 @@ def _syft_output_action_object( constructor = action_type_for_type(result) syft_twin_type = TwinMode.NONE - if context.result_twin_type is not None: + if context is not None and context.result_twin_type is not None: syft_twin_type = context.result_twin_type result = constructor( syft_twin_type=syft_twin_type, @@ -1426,11 +1438,13 @@ def _syft_get_attr_context(self, name: str) -> Any: # use the custom defined version context_self = self if not defined_on_self: - context_self = self.syft_action_data # type: ignore + context_self = self.syft_action_data return context_self - def _syft_attr_propagate_ids(self, context, name: str, result: Any) -> Any: + def _syft_attr_propagate_ids( + self, context: PreHookContext, name: str, result: Any + ) -> Any: """Patch the results with the syft_history_hash, node_uid, and result_id.""" if name in self._syft_dont_wrap_attrs(): return result @@ -1574,7 +1588,7 @@ def _base_wrapper(*args: Any, **kwargs: Any) -> Any: if inspect.ismethod(original_func) or inspect.ismethoddescriptor(original_func): debug("Running method: ", name) - def wrapper(_self: Any, *args: Any, **kwargs: Any): + def wrapper(_self: Any, *args: Any, **kwargs: Any) -> Any: return _base_wrapper(*args, **kwargs) wrapper = types.MethodType(wrapper, type(self)) @@ -1599,9 +1613,9 @@ def wrapper(_self: Any, *args: Any, **kwargs: Any): # third party return wrapper - def _syft_setattr(self, name, value): + def _syft_setattr(self, name: str, value: Any) -> Any: args = (name, value) - kwargs = {} + kwargs: dict = {} op_name = "__setattr__" def fake_func(*args: Any, **kwargs: Any) -> Any: @@ -1685,7 +1699,7 @@ def __setattr__(self, name: str, value: Any) -> Any: return value else: self._syft_setattr(name, value) - context_self = self.syft_action_data # type: ignore + context_self = self.syft_action_data return context_self.__setattr__(name, value) # def keys(self) -> KeysView[str]: @@ -1710,11 +1724,12 @@ def _repr_markdown_(self) -> str: if inspect.isclass(self.syft_action_data_cache): data_repr_ = repr_cls(self.syft_action_data_cache) else: - data_repr_ = ( - self.syft_action_data_cache._repr_markdown_() - if hasattr(self.syft_action_data_cache, "_repr_markdown_") - else self.syft_action_data_cache.__repr__() - ) + if self.syft_action_data_cache is not None: + data_repr_ = ( + self.syft_action_data_cache._repr_markdown_() + if hasattr(self.syft_action_data_cache, "_repr_markdown_") + else self.syft_action_data_cache.__repr__() + ) return f"```python\n{res}\n```\n{data_repr_}" @@ -1842,10 +1857,10 @@ def __lshift__(self, other: Any) -> Any: def __rshift__(self, other: Any) -> Any: return self._syft_output_action_object(self.__rshift__(other)) - def __iter__(self): + def __iter__(self) -> Any: return self._syft_output_action_object(self.__iter__()) - def __next__(self): + def __next__(self) -> Any: return self._syft_output_action_object(self.__next__()) # r ops @@ -1891,14 +1906,14 @@ def __rrshift__(self, other: Any) -> Any: @migrate(ActionObject, ActionObjectV1) -def downgrade_actionobject_v2_to_v1(): +def downgrade_actionobject_v2_to_v1() -> list[Callable]: return [ drop("syft_resolved"), ] @migrate(ActionObjectV1, ActionObject) -def upgrade_actionobject_v1_to_v2(): +def upgrade_actionobject_v1_to_v2() -> list[Callable]: return [ make_set_default("syft_resolved", True), ] @@ -1932,7 +1947,7 @@ def __int__(self) -> float: @migrate(AnyActionObject, AnyActionObjectV1) -def downgrade_anyactionobject_v2_to_v1(): +def downgrade_anyactionobject_v2_to_v1() -> list[Callable]: return [ drop("syft_action_data_str"), drop("syft_resolved"), @@ -1940,7 +1955,7 @@ def downgrade_anyactionobject_v2_to_v1(): @migrate(AnyActionObjectV1, AnyActionObject) -def upgrade_anyactionobject_v1_to_v2(): +def upgrade_anyactionobject_v1_to_v2() -> list[Callable]: return [ make_set_default("syft_action_data_str", ""), make_set_default("syft_resolved", True), diff --git a/packages/syft/src/syft/service/action/action_permissions.py b/packages/syft/src/syft/service/action/action_permissions.py index e358de3be1a..d8620d0c396 100644 --- a/packages/syft/src/syft/service/action/action_permissions.py +++ b/packages/syft/src/syft/service/action/action_permissions.py @@ -46,7 +46,10 @@ def permission_string(self) -> str: if self.permission in COMPOUND_ACTION_PERMISSION: return f"{self.permission.name}" else: - return f"{self.credentials.verify}_{self.permission.name}" + if self.credentials is not None: + return f"{self.credentials.verify}_{self.permission.name}" + else: + return f"{self.permission.name}" def __repr__(self) -> str: if self.credentials is not None: diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 0c46b9871e9..e3db8f79fcf 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -4,6 +4,7 @@ from typing import Dict from typing import List from typing import Optional +from typing import Tuple from typing import Union # third party @@ -13,6 +14,7 @@ from result import Result # relative +from ...node.credentials import SyftVerifyKey from ...serde.serializable import serializable from ...types.datetime import DateTime from ...types.syft_object import SyftObject @@ -22,6 +24,7 @@ from ..code.user_code import UserCode from ..code.user_code import execute_byte_code from ..context import AuthedServiceContext +from ..policy.policy import OutputPolicy from ..policy.policy import retrieve_from_db from ..response import SyftError from ..response import SyftSuccess @@ -202,7 +205,7 @@ def _get( context: AuthedServiceContext, uid: UID, twin_mode: TwinMode = TwinMode.PRIVATE, - has_permission=False, + has_permission: bool = False, resolve_nested: bool = True, ) -> Result[ActionObject, str]: """Get an object from the action store""" @@ -382,7 +385,12 @@ def _user_code_execute( return Err(f"_user_code_execute failed. {e}") return Ok(result_action_object) - def set_result_to_store(self, result_action_object, context, output_policy=None): + def set_result_to_store( + self, + result_action_object: ActionObject, + context: AuthedServiceContext, + output_policy: Optional[OutputPolicy] = None, + ) -> Union[Result[ActionObject, str], SyftError]: result_id = result_action_object.id # result_blob_id = result_action_object.syft_blob_storage_entry_id @@ -423,10 +431,14 @@ def set_result_to_store(self, result_action_object, context, output_policy=None) BlobStorageService ) - def store_permission(x): + def store_permission( + x: Optional[SyftVerifyKey] = None, + ) -> ActionObjectPermission: return ActionObjectPermission(result_id, read_permission, x) - def blob_permission(x): + def blob_permission( + x: Optional[SyftVerifyKey] = None, + ) -> ActionObjectPermission: return ActionObjectPermission(result_blob_id, read_permission, x) if len(output_readers) > 0: @@ -439,8 +451,11 @@ def blob_permission(x): return set_result def execute_plan( - self, plan, context: AuthedServiceContext, plan_kwargs: Dict[str, ActionObject] - ): + self, + plan: Any, + context: AuthedServiceContext, + plan_kwargs: Dict[str, ActionObject], + ) -> Union[Result[ActionObject, str], SyftError]: id2inpkey = {v.id: k for k, v in plan.inputs.items()} for plan_action in plan.actions: @@ -466,7 +481,9 @@ def execute_plan( result_id = plan.outputs[0].id return self._get(context, result_id, TwinMode.MOCK, has_permission=True) - def call_function(self, context: AuthedServiceContext, action: Action): + def call_function( + self, context: AuthedServiceContext, action: Action + ) -> Union[Result[ActionObject, str], Err]: # run function/class init _user_lib_config_registry = UserLibConfigRegistry.from_user(context.credentials) absolute_path = f"{action.path}.{action.op}" @@ -484,7 +501,7 @@ def set_attribute( context: AuthedServiceContext, action: Action, resolved_self: Union[ActionObject, TwinObject], - ): + ) -> Result[Union[TwinObject, ActionObject], str]: args, _ = resolve_action_args(action, context, self) if args.is_err(): return Err( @@ -536,7 +553,7 @@ def set_attribute( def get_attribute( self, action: Action, resolved_self: Union[ActionObject, TwinObject] - ): + ) -> Ok[Union[TwinObject, ActionObject]]: if isinstance(resolved_self, TwinObject): private_result = getattr(resolved_self.private.syft_action_data, action.op) mock_result = getattr(resolved_self.mock.syft_action_data, action.op) @@ -558,7 +575,7 @@ def call_method( context: AuthedServiceContext, action: Action, resolved_self: Union[ActionObject, TwinObject], - ): + ) -> Result[Union[TwinObject, Any], str]: if isinstance(resolved_self, TwinObject): # method private_result = execute_object( @@ -612,7 +629,7 @@ def execute( for k, v in action.kwargs.items(): # transform lineage ids into ids kwarg_ids[k] = v.id - result_action_object: Result[ActionObject, Err] = usercode_service._call( + result_action_object = usercode_service._call( context, action.user_code_id, action.result_id, **kwarg_ids ) return result_action_object @@ -707,7 +724,7 @@ def exists( def resolve_action_args( action: Action, context: AuthedServiceContext, service: ActionService -): +) -> Tuple[Ok[Dict], bool]: has_twin_inputs = False args = [] for arg_id in action.args: @@ -724,7 +741,7 @@ def resolve_action_args( def resolve_action_kwargs( action: Action, context: AuthedServiceContext, service: ActionService -): +) -> Tuple[Ok[Dict], bool]: has_twin_inputs = False kwargs = {} for key, arg_id in action.kwargs.items(): @@ -761,7 +778,7 @@ def execute_callable( # stdlib # TODO: get from CMPTree is probably safer - def _get_target_callable(path: str, op: str): + def _get_target_callable(path: str, op: str) -> Any: path_elements = path.split(".") res = importlib.import_module(path_elements[0]) for p in path_elements[1:]: @@ -861,15 +878,15 @@ def execute_object( private_obj=result_action_object_private, mock_obj=result_action_object_mock, ) - elif twin_mode == twin_mode.PRIVATE: # type: ignore + elif twin_mode == twin_mode.PRIVATE: # type:ignore # twin private path - private_args = filter_twin_args(args, twin_mode=twin_mode) + private_args = filter_twin_args(args, twin_mode=twin_mode) # type:ignore[unreachable] private_kwargs = filter_twin_kwargs(kwargs, twin_mode=twin_mode) result = target_method(*private_args, **private_kwargs) result_action_object = wrap_result(action.result_id, result) - elif twin_mode == twin_mode.MOCK: # type: ignore + elif twin_mode == twin_mode.MOCK: # type:ignore # twin mock path - mock_args = filter_twin_args(args, twin_mode=twin_mode) + mock_args = filter_twin_args(args, twin_mode=twin_mode) # type:ignore[unreachable] mock_kwargs = filter_twin_kwargs(kwargs, twin_mode=twin_mode) target_method = getattr(unboxed_resolved_self, action.op, None) result = target_method(*mock_args, **mock_kwargs) diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py index b939de6aada..d44e5181498 100644 --- a/packages/syft/src/syft/service/action/action_store.py +++ b/packages/syft/src/syft/service/action/action_store.py @@ -65,7 +65,7 @@ def __init__( self.root_verify_key = root_verify_key def get( - self, uid: UID, credentials: SyftVerifyKey, has_permission=False + self, uid: UID, credentials: SyftVerifyKey, has_permission: bool = False ) -> Result[SyftObject, str]: uid = uid.id # We only need the UID from LineageID or UID @@ -212,7 +212,10 @@ def has_permission(self, permission: ActionObjectPermission) -> bool: if not isinstance(permission.permission, ActionPermission): raise Exception(f"ObjectPermission type: {permission.permission} not valid") - if self.root_verify_key.verify == permission.credentials.verify: + if ( + permission.credentials is not None + and self.root_verify_key.verify == permission.credentials.verify + ): return True if ( @@ -241,7 +244,7 @@ def add_permission(self, permission: ActionObjectPermission) -> None: permissions.add(permission.permission_string) self.permissions[permission.uid] = permissions - def remove_permission(self, permission: ActionObjectPermission): + def remove_permission(self, permission: ActionObjectPermission) -> None: permissions = self.permissions[permission.uid] permissions.remove(permission.permission_string) self.permissions[permission.uid] = permissions @@ -250,7 +253,9 @@ def add_permissions(self, permissions: List[ActionObjectPermission]) -> None: for permission in permissions: self.add_permission(permission) - def migrate_data(self, to_klass: SyftObject, credentials: SyftVerifyKey): + def migrate_data( + self, to_klass: SyftObject, credentials: SyftVerifyKey + ) -> Result[bool, str]: has_root_permission = credentials == self.root_verify_key if has_root_permission: diff --git a/packages/syft/src/syft/service/action/action_types.py b/packages/syft/src/syft/service/action/action_types.py index bb9fa98504b..3fbe4b9c9f5 100644 --- a/packages/syft/src/syft/service/action/action_types.py +++ b/packages/syft/src/syft/service/action/action_types.py @@ -6,7 +6,7 @@ from ...util.logger import debug from .action_data_empty import ActionDataEmpty -action_types = {} +action_types: dict = {} def action_type_for_type(obj_or_type: Any) -> Type: diff --git a/packages/syft/src/syft/service/action/numpy.py b/packages/syft/src/syft/service/action/numpy.py index 45c778b58ab..dfd43907b92 100644 --- a/packages/syft/src/syft/service/action/numpy.py +++ b/packages/syft/src/syft/service/action/numpy.py @@ -1,10 +1,13 @@ # stdlib from typing import Any +from typing import Callable from typing import ClassVar from typing import Type +from typing import Union # third party import numpy as np +from typing_extensions import Self # relative from ...serde.serializable import serializable @@ -14,6 +17,7 @@ from ...types.transforms import drop from ...types.transforms import make_set_default from .action_object import ActionObject +from .action_object import ActionObjectPointer from .action_object import ActionObjectV1 from .action_object import BASE_PASSTHROUGH_ATTRS from .action_types import action_types @@ -28,7 +32,7 @@ # return domain_client.api.services.action.get(self.id).syft_action_data -class NumpyArrayObjectPointer: +class NumpyArrayObjectPointer(ActionObjectPointer): pass @@ -74,7 +78,9 @@ class NumpyArrayObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): # ) # return self == other - def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + def __array_ufunc__( + self, ufunc: Any, method: str, *inputs: Any, **kwargs: Any + ) -> Union[Self, tuple[Self, ...]]: inputs = tuple( np.array(x.syft_action_data, dtype=x.dtype) if isinstance(x, NumpyArrayObject) @@ -95,14 +101,14 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): @migrate(NumpyArrayObject, NumpyArrayObjectV1) -def downgrade_numpyarrayobject_v2_to_v1(): +def downgrade_numpyarrayobject_v2_to_v1() -> list[Callable]: return [ drop("syft_resolved"), ] @migrate(NumpyArrayObjectV1, NumpyArrayObject) -def upgrade_numpyarrayobject_v1_to_v2(): +def upgrade_numpyarrayobject_v1_to_v2() -> list[Callable]: return [ make_set_default("syft_resolved", True), ] @@ -132,14 +138,14 @@ def __float__(self) -> float: @migrate(NumpyScalarObject, NumpyScalarObjectV1) -def downgrade_numpyscalarobject_v2_to_v1(): +def downgrade_numpyscalarobject_v2_to_v1() -> list[Callable]: return [ drop("syft_resolved"), ] @migrate(NumpyScalarObjectV1, NumpyScalarObject) -def upgrade_numpyscalarobject_v1_to_v2(): +def upgrade_numpyscalarobject_v1_to_v2() -> list[Callable]: return [ make_set_default("syft_resolved", True), ] @@ -166,14 +172,14 @@ class NumpyBoolObject(ActionObject, np.lib.mixins.NDArrayOperatorsMixin): @migrate(NumpyBoolObject, NumpyBoolObjectV1) -def downgrade_numpyboolobject_v2_to_v1(): +def downgrade_numpyboolobject_v2_to_v1() -> list[Callable]: return [ drop("syft_resolved"), ] @migrate(NumpyBoolObjectV1, NumpyBoolObject) -def upgrade_numpyboolobject_v1_to_v2(): +def upgrade_numpyboolobject_v1_to_v2() -> list[Callable]: return [ make_set_default("syft_resolved", True), ] diff --git a/packages/syft/src/syft/service/action/pandas.py b/packages/syft/src/syft/service/action/pandas.py index a466545b363..2dac63f3b46 100644 --- a/packages/syft/src/syft/service/action/pandas.py +++ b/packages/syft/src/syft/service/action/pandas.py @@ -1,5 +1,6 @@ # stdlib from typing import Any +from typing import Callable from typing import ClassVar from typing import Type @@ -56,14 +57,14 @@ def syft_is_property(self, obj: Any, method: str) -> bool: @migrate(PandasDataFrameObject, PandasDataFrameObjectV1) -def downgrade_pandasdataframeobject_v2_to_v1(): +def downgrade_pandasdataframeobject_v2_to_v1() -> list[Callable]: return [ drop("syft_resolved"), ] @migrate(PandasDataFrameObjectV1, PandasDataFrameObject) -def upgrade_pandasdataframeobject_v1_to_v2(): +def upgrade_pandasdataframeobject_v1_to_v2() -> list[Callable]: return [ make_set_default("syft_resolved", True), ] @@ -105,14 +106,14 @@ def syft_is_property(self, obj: Any, method: str) -> bool: @migrate(PandasSeriesObject, PandasSeriesObjectV1) -def downgrade_pandasseriesframeobject_v2_to_v1(): +def downgrade_pandasseriesframeobject_v2_to_v1() -> list[Callable]: return [ drop("syft_resolved"), ] @migrate(PandasSeriesObjectV1, PandasSeriesObject) -def upgrade_pandasseriesframeobject_v1_to_v2(): +def upgrade_pandasseriesframeobject_v1_to_v2() -> list[Callable]: return [ make_set_default("syft_resolved", True), ] diff --git a/packages/syft/src/syft/service/action/plan.py b/packages/syft/src/syft/service/action/plan.py index 298f34693bc..a2a81b6f473 100644 --- a/packages/syft/src/syft/service/action/plan.py +++ b/packages/syft/src/syft/service/action/plan.py @@ -1,9 +1,11 @@ # stdlib import inspect +from typing import Any from typing import Callable from typing import Dict from typing import List from typing import Optional +from typing import Union # relative from ... import ActionObject @@ -45,17 +47,19 @@ def __repr__(self) -> str: return f"{obj_str}\n{inp_str}\n{act_str}\n{out_str}\n\n{plan_str}" - def remap_actions_to_inputs(self, **new_inputs): + def remap_actions_to_inputs(self, **new_inputs: Any) -> None: pass - def __call__(self, *args, **kwargs): + def __call__( + self, *args: Any, **kwargs: Any + ) -> Union[ActionObject, list[ActionObject]]: if len(self.outputs) == 1: return self.outputs[0] else: return self.outputs -def planify(func): +def planify(func: Callable) -> ActionObject: TraceResult.reset() ActionObject.add_trace_hook() TraceResult.is_tracing = True From 22b4c3946000b3be4973d19b8513b0b9a576ad43 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Fri, 16 Feb 2024 15:17:51 +0700 Subject: [PATCH 07/42] [tests] skip `test_transfer_request_nonblocking` on windows due to flakiness --- packages/syft/tests/syft/request/request_multiple_nodes_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/packages/syft/tests/syft/request/request_multiple_nodes_test.py b/packages/syft/tests/syft/request/request_multiple_nodes_test.py index 9ec214ea7fe..502d2ea19d2 100644 --- a/packages/syft/tests/syft/request/request_multiple_nodes_test.py +++ b/packages/syft/tests/syft/request/request_multiple_nodes_test.py @@ -1,5 +1,6 @@ # stdlib import secrets +import sys from textwrap import dedent # third party @@ -149,6 +150,7 @@ def compute_sum(data) -> float: @pytest.mark.flaky(reruns=2, reruns_delay=1) +@pytest.mark.skipif(sys.platform == "win32", reason="very flaky on windows") def test_transfer_request_nonblocking( client_ds_1, client_do_1, client_do_2, dataset_1, dataset_2 ): From 2d32d24a7015f30180957f93dae30cbc432a4fb3 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 19 Feb 2024 08:37:24 +0700 Subject: [PATCH 08/42] [refactor] on fixing mypy issues for `service/queue/` - comment out "--warn-unused-ignores" in precommit mypy - small mypy fixes in `service/action` Co-authored-by: Kien Dang --- .pre-commit-config.yaml | 2 +- .../src/syft/service/action/action_graph.py | 19 ++++++++++--------- .../src/syft/service/action/action_service.py | 2 ++ .../syft/src/syft/service/queue/base_queue.py | 6 +++--- .../src/syft/service/queue/queue_stash.py | 10 ++++++---- .../syft/src/syft/service/queue/zmq_queue.py | 9 +++++---- .../request/request_multiple_nodes_test.py | 2 -- 7 files changed, 27 insertions(+), 23 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b4c616c2305..8eeb1168c06 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -180,7 +180,7 @@ repos: "--scripts-are-modules", "--disallow-incomplete-defs", "--no-implicit-optional", - "--warn-unused-ignores", + # "--warn-unused-ignores", "--warn-redundant-casts", "--strict-equality", "--warn-unreachable", diff --git a/packages/syft/src/syft/service/action/action_graph.py b/packages/syft/src/syft/service/action/action_graph.py index f990e1bddfd..a2423c0ed07 100644 --- a/packages/syft/src/syft/service/action/action_graph.py +++ b/packages/syft/src/syft/service/action/action_graph.py @@ -21,6 +21,7 @@ from result import Err from result import Ok from result import Result +from typing_extensions import Self # relative from ...node.credentials import SyftVerifyKey @@ -77,23 +78,23 @@ class NodeActionData(SyftObject): def make_created_at(cls, v: Optional[DateTime]) -> DateTime: return DateTime.now() if v is None else v - @staticmethod - def from_action(action: Action, credentials: SyftVerifyKey) -> "NodeActionData": + @classmethod + def from_action(cls, action: Action, credentials: SyftVerifyKey) -> Self: is_mutagen = action.remote_self is not None and ( action.remote_self == action.result_id ) - return NodeActionData( + return cls( id=action.id, type=NodeType.ACTION, user_verify_key=credentials, is_mutagen=is_mutagen, ) - @staticmethod + @classmethod def from_action_obj( - action_obj: ActionObject, credentials: SyftVerifyKey - ) -> "NodeActionData": - return NodeActionData( + cls, action_obj: ActionObject, credentials: SyftVerifyKey + ) -> Self: + return cls( id=action_obj.id, type=NodeType.ACTION_OBJECT, user_verify_key=credentials, @@ -149,7 +150,7 @@ def get(self, uid: Any) -> Any: def delete(self, uid: Any) -> None: raise NotImplementedError - def find_neighbors(self, uid: Any) -> Optional[List[Any]]: + def find_neighbors(self, uid: Any) -> Optional[List]: raise NotImplementedError def update(self, uid: Any, data: Any) -> None: @@ -270,7 +271,7 @@ def _delete(self, uid: UID) -> None: self.db.remove_node(uid) self.save() - def find_neighbors(self, uid: UID) -> Optional[List[Any]]: + def find_neighbors(self, uid: UID) -> Optional[List]: if self.exists(uid=uid): neighbors = self.db.neighbors(uid) return neighbors diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index e3db8f79fcf..7a34090f45b 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -60,6 +60,8 @@ def __init__(self, store: ActionStore) -> None: def np_array(self, context: AuthedServiceContext, data: Any) -> Any: if not isinstance(data, np.ndarray): data = np.array(data) + if context.node is None: + return SyftError(message=f"context {context}'s node is None") np_obj = NumpyArrayObject( dtype=data.dtype, shape=data.shape, diff --git a/packages/syft/src/syft/service/queue/base_queue.py b/packages/syft/src/syft/service/queue/base_queue.py index 5a5426242ca..65bb0002eb2 100644 --- a/packages/syft/src/syft/service/queue/base_queue.py +++ b/packages/syft/src/syft/service/queue/base_queue.py @@ -23,7 +23,7 @@ class AbstractMessageHandler: queue_name: ClassVar[str] @staticmethod - def handle_message(message: bytes, syft_worker_id: UID): + def handle_message(message: bytes, syft_worker_id: UID) -> None: raise NotImplementedError @@ -33,7 +33,7 @@ class QueueConsumer: queue_name: str address: str - def receive(self): + def receive(self) -> None: raise NotImplementedError def run(self) -> None: @@ -51,7 +51,7 @@ class QueueProducer: def send( self, message: Any, - ): + ) -> None: raise NotImplementedError def close(self) -> None: diff --git a/packages/syft/src/syft/service/queue/queue_stash.py b/packages/syft/src/syft/service/queue/queue_stash.py index 60e23644f78..e98641dc83a 100644 --- a/packages/syft/src/syft/service/queue/queue_stash.py +++ b/packages/syft/src/syft/service/queue/queue_stash.py @@ -1,9 +1,11 @@ # stdlib from enum import Enum from typing import Any +from typing import Callable from typing import Dict from typing import List from typing import Optional +from typing import Union # third party from result import Ok @@ -107,18 +109,18 @@ def _repr_markdown_(self) -> str: return f": {self.status}" @property - def is_action(self): + def is_action(self) -> bool: return self.service_path == "Action" and self.method_name == "execute" @property - def action(self): + def action(self) -> Union[Any, SyftError]: if self.is_action: return self.kwargs["action"] return SyftError(message="QueueItem not an Action") @migrate(QueueItem, QueueItemV1) -def downgrade_queueitem_v2_to_v1(): +def downgrade_queueitem_v2_to_v1() -> list[Callable]: return [ drop( [ @@ -135,7 +137,7 @@ def downgrade_queueitem_v2_to_v1(): @migrate(QueueItemV1, QueueItem) -def upgrade_queueitem_v1_to_v2(): +def upgrade_queueitem_v1_to_v2() -> list[Callable]: return [ make_set_default("method", ""), make_set_default("service", ""), diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py index a6a346b6d11..1d7afdf4ce3 100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -6,6 +6,7 @@ import threading import time from time import sleep +from typing import Any from typing import DefaultDict from typing import Dict from typing import List @@ -88,13 +89,13 @@ def __init__(self, offset_sec: float): self.reset() @property - def next_ts(self): + def next_ts(self) -> Union[float, int]: return self.__next_ts - def reset(self): + def reset(self) -> None: self.__next_ts = self.now() + self.__offset - def has_expired(self): + def has_expired(self) -> bool: return self.now() >= self.__next_ts @staticmethod @@ -110,7 +111,7 @@ class Worker(SyftBaseModel): expiry_t: Timeout = Timeout(WORKER_TIMEOUT_SEC) @validator("syft_worker_id", pre=True, always=True) - def set_syft_worker_id(cls, v, values): + def set_syft_worker_id(cls, v: Any, values: Any) -> Union[UID, Any]: if isinstance(v, str): return UID(v) return v diff --git a/packages/syft/tests/syft/request/request_multiple_nodes_test.py b/packages/syft/tests/syft/request/request_multiple_nodes_test.py index 502d2ea19d2..9ec214ea7fe 100644 --- a/packages/syft/tests/syft/request/request_multiple_nodes_test.py +++ b/packages/syft/tests/syft/request/request_multiple_nodes_test.py @@ -1,6 +1,5 @@ # stdlib import secrets -import sys from textwrap import dedent # third party @@ -150,7 +149,6 @@ def compute_sum(data) -> float: @pytest.mark.flaky(reruns=2, reruns_delay=1) -@pytest.mark.skipif(sys.platform == "win32", reason="very flaky on windows") def test_transfer_request_nonblocking( client_ds_1, client_do_1, client_do_2, dataset_1, dataset_2 ): From 70e21eb37cb83d03b68f540e994a5e724ffce10f Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 19 Feb 2024 09:32:31 +0700 Subject: [PATCH 09/42] [refactor] done fixing mypy issues for `base_queue.py` and `queue.py` --- .pre-commit-config.yaml | 2 +- .../syft/src/syft/service/queue/base_queue.py | 17 ++++-- packages/syft/src/syft/service/queue/queue.py | 56 +++++++++++-------- 3 files changed, 48 insertions(+), 27 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8eeb1168c06..4bdc6d1e740 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -173,7 +173,7 @@ repos: name: "mypy: syft" always_run: true # files: "^packages/syft/src/syft/serde|^packages/syft/src/syft/util|^packages/syft/src/syft/service" - files: "^packages/syft/src/syft/service/action" + files: "^packages/syft/src/syft/service/queue/base_queue.py" args: [ "--follow-imports=skip", "--ignore-missing-imports", diff --git a/packages/syft/src/syft/service/queue/base_queue.py b/packages/syft/src/syft/service/queue/base_queue.py index 65bb0002eb2..6e2d811fe22 100644 --- a/packages/syft/src/syft/service/queue/base_queue.py +++ b/packages/syft/src/syft/service/queue/base_queue.py @@ -7,6 +7,8 @@ # relative from ...serde.serializable import serializable +from ...service.context import AuthedServiceContext +from ...store.document_store import BaseStash from ...types.uid import UID from ..response import SyftError from ..response import SyftSuccess @@ -60,7 +62,8 @@ def close(self) -> None: @serializable() class QueueClient: - pass + def __init__(self, config: QueueClientConfig) -> None: + raise NotImplementedError @serializable() @@ -82,20 +85,26 @@ def __init__(self, config: QueueConfig): def post_init(self) -> None: pass - def close(self) -> None: + def close(self) -> Union[SyftError, SyftSuccess]: raise NotImplementedError def create_consumer( self, message_handler: Type[AbstractMessageHandler], - address: Optional[str], service_name: str, worker_stash: Optional[WorkerStash] = None, + address: Optional[str] = None, syft_worker_id: Optional[UID] = None, ) -> QueueConsumer: raise NotImplementedError - def create_producer(self, queue_name: str) -> QueueProducer: + def create_producer( + self, + queue_name: str, + queue_stash: Type[BaseStash], + context: AuthedServiceContext, + worker_stash: WorkerStash, + ) -> QueueProducer: raise NotImplementedError def send(self, message: bytes, queue_name: str) -> Union[SyftSuccess, SyftError]: diff --git a/packages/syft/src/syft/service/queue/queue.py b/packages/syft/src/syft/service/queue/queue.py index 1f81ed11bc3..2ceaefc6fb1 100644 --- a/packages/syft/src/syft/service/queue/queue.py +++ b/packages/syft/src/syft/service/queue/queue.py @@ -1,5 +1,4 @@ # stdlib -import multiprocessing import threading import time from typing import Any @@ -15,9 +14,12 @@ # relative from ...node.credentials import SyftVerifyKey +from ...node.worker import Worker +from ...node.worker_settings import WorkerSettings from ...serde.deserialize import _deserialize as deserialize from ...serde.serializable import serializable from ...service.context import AuthedServiceContext +from ...store.document_store import BaseStash from ...types.datetime import DateTime from ...types.uid import UID from ..job.job_stash import Job @@ -37,7 +39,11 @@ class MonitorThread(threading.Thread): def __init__( - self, queue_item: QueueItem, worker, credentials: SyftVerifyKey, interval=5 + self, + queue_item: QueueItem, + worker: Worker, + credentials: SyftVerifyKey, + interval: int = 5, ): super().__init__() self.interval = interval @@ -46,12 +52,12 @@ def __init__( self.worker = worker self.queue_item = queue_item - def run(self): + def run(self) -> None: while not self.stop_requested.is_set(): self.monitor() time.sleep(self.interval) - def monitor(self): + def monitor(self) -> None: # Implement the monitoring logic here job = self.worker.job_stash.get_by_uid( self.credentials, self.queue_item.job_id @@ -67,7 +73,7 @@ def monitor(self): process = psutil.Process(job.job_pid) process.terminate() - def stop(self): + def stop(self) -> None: self.stop_requested.set() @@ -75,11 +81,11 @@ def stop(self): class QueueManager(BaseQueueManager): config: QueueConfig - def post_init(self): + def post_init(self) -> None: self.client_config = self.config.client_config self._client = self.config.client_type(self.client_config) - def close(self): + def close(self) -> Union[SyftError, SyftSuccess]: return self._client.close() def create_consumer( @@ -103,7 +109,7 @@ def create_consumer( def create_producer( self, queue_name: str, - queue_stash, + queue_stash: Type[BaseStash], context: AuthedServiceContext, worker_stash: WorkerStash, ) -> QueueProducer: @@ -125,15 +131,19 @@ def send( ) @property - def producers(self): + def producers(self) -> Any: return self._client.producers @property - def consumers(self): + def consumers(self) -> Any: return self._client.consumers -def handle_message_multiprocessing(worker_settings, queue_item, credentials): +def handle_message_multiprocessing( + worker_settings: WorkerSettings, + queue_item: QueueItem, + credentials: SyftVerifyKey, +) -> None: # this is a temp hack to prevent some multithreading issues time.sleep(0.5) queue_config = worker_settings.queue_config @@ -258,7 +268,7 @@ class APICallMessageHandler(AbstractMessageHandler): queue_name = "api_call" @staticmethod - def handle_message(message: bytes, syft_worker_id: UID): + def handle_message(message: bytes, syft_worker_id: UID) -> None: # relative from ...node.node import Node @@ -312,29 +322,31 @@ def handle_message(message: bytes, syft_worker_id: UID): queue_result = worker.queue_stash.set_result(credentials, queue_item) if isinstance(queue_result, SyftError): - raise Exception(message=f"{queue_result.err()}") + raise Exception(f"{queue_result.err()}") job_result = worker.job_stash.set_result(credentials, job_item) if isinstance(job_result, SyftError): - raise Exception(message=f"{job_result.err()}") + raise Exception(f"{job_result.err()}") if queue_config.thread_workers: # stdlib from threading import Thread - p = Thread( + thread = Thread( target=handle_message_multiprocessing, args=(worker_settings, queue_item, credentials), ) - p.start() + thread.start() + thread.join() else: - p = multiprocessing.Process( + # stdlib + from multiprocessing import Process + + process = Process( target=handle_message_multiprocessing, args=(worker_settings, queue_item, credentials), ) - p.start() - - job_item.job_pid = p.pid + process.start() + job_item.job_pid = process.pid worker.job_stash.set_result(credentials, job_item) - - p.join() + process.join() From f40e407cbcb81aecfdccb882fd782d1d4c6261d5 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 19 Feb 2024 09:54:34 +0700 Subject: [PATCH 10/42] [refactor] fix cicular import error in queue.py --- .pre-commit-config.yaml | 2 +- .../32b1f4e9972b4f9583f99b67944a1623.json | 304 ++++++++++++++++++ .../8234fa441ef844f7bf2cb1c05cad9a7a.json | 304 ++++++++++++++++++ packages/syft/src/syft/service/queue/queue.py | 5 +- .../syft/src/syft/service/queue/zmq_queue.py | 5 +- 5 files changed, 615 insertions(+), 5 deletions(-) create mode 100644 packages/syft/src/syft/protocol/32b1f4e9972b4f9583f99b67944a1623.json create mode 100644 packages/syft/src/syft/protocol/8234fa441ef844f7bf2cb1c05cad9a7a.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4bdc6d1e740..cd54f42b1f6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -173,7 +173,7 @@ repos: name: "mypy: syft" always_run: true # files: "^packages/syft/src/syft/serde|^packages/syft/src/syft/util|^packages/syft/src/syft/service" - files: "^packages/syft/src/syft/service/queue/base_queue.py" + files: "^packages/syft/src/syft/service/queue/queue.py|^packages/syft/src/syft/service/queue/base_queue.py" args: [ "--follow-imports=skip", "--ignore-missing-imports", diff --git a/packages/syft/src/syft/protocol/32b1f4e9972b4f9583f99b67944a1623.json b/packages/syft/src/syft/protocol/32b1f4e9972b4f9583f99b67944a1623.json new file mode 100644 index 00000000000..86f1e20ca5e --- /dev/null +++ b/packages/syft/src/syft/protocol/32b1f4e9972b4f9583f99b67944a1623.json @@ -0,0 +1,304 @@ +{ + "1": { + "release_name": "0.8.2.json" + }, + "2": { + "release_name": "0.8.3.json" + }, + "dev": { + "object_versions": { + "SyftWorkerImage": { + "1": { + "version": 1, + "hash": "2a9585b6a286e24f1a9f3f943d0128730cf853edc549184dc1809d19e1eec54b", + "action": "add" + } + }, + "ActionDataLink": { + "1": { + "version": 1, + "hash": "10bf94e99637695f1ba283f0b10e70743a4ebcb9ee75aefb1a05e6d6e1d21a71", + "action": "add" + } + }, + "ObjectNotReady": { + "1": { + "version": 1, + "hash": "88207988639b11eaca686b6e079616d9caecc3dbc2a8112258e0f39ee5c3e113", + "action": "add" + } + }, + "JobItem": { + "3": { + "version": 3, + "hash": "5b93a59e28574691339d22826d5650969336a2e930b93d6b3fe6d5409ca0cfc4", + "action": "add" + } + }, + "SeaweedSecureFilePathLocation": { + "2": { + "version": 2, + "hash": "5fd63fed2a4efba8c2b6c7a7b5e9b5939181781c331230896aa130b6fd558739", + "action": "add" + } + }, + "AzureSecureFilePathLocation": { + "1": { + "version": 1, + "hash": "1bb15f3f9d7082779f1c9f58de94011487924cb8a8c9c2ec18fd7c161c27fd0e", + "action": "add" + } + }, + "RemoteConfig": { + "1": { + "version": 1, + "hash": "ad7bc4780a8ad52e14ce68601852c93d2fe07bda489809cad7cae786d2461754", + "action": "add" + } + }, + "AzureRemoteConfig": { + "1": { + "version": 1, + "hash": "c05c6caa27db4e385c642536d4b0ecabc0c71e91220d2e6ce21a2761ca68a673", + "action": "add" + } + }, + "BlobRetrievalByURL": { + "2": { + "version": 2, + "hash": "8059ee03016c4d74e408dad9529e877f91829672e0cc42d8cfff9c8e14058adc", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "0b664100ea08413ca4ef04665ca910c2cf9535539617ea4ba33687d05cdfe747", + "action": "add" + } + }, + "QueueItem": { + "3": { + "version": 3, + "hash": "3495f406d2c97050ce86be80c230f49b6b846c63b9a9230cbd6631952f2bad0f", + "action": "add" + } + }, + "ActionQueueItem": { + "2": { + "version": 2, + "hash": "6413ed01e949cac169299a43ce40651f9bf8053e408b6942853f8afa8a693b3d", + "action": "add" + } + }, + "ZMQClientConfig": { + "2": { + "version": 2, + "hash": "0f9bc88d56cd6eed6fc75459d1f914aed840c66e1195b9e41cc501b488fef2ed", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "91ce5953cced58e12c576aa5174d5ca0c91981b01cf42edd5283d347baa3390b", + "action": "add" + } + }, + "SyftWorker": { + "1": { + "version": 1, + "hash": "0d5b367162f3ce55ab090cc1b49bd30e50d4eb144e8431eadc679bd0e743aa70", + "action": "add" + } + }, + "WorkerPool": { + "1": { + "version": 1, + "hash": "250699eb4c452fc427995353d5c5ad6245fb3e9fdac8814f8348784816a0733b", + "action": "add" + } + }, + "SyftImageRegistry": { + "1": { + "version": 1, + "hash": "dc83910c91947e3d9eaa3e6f8592237448f0408668c7cca80450b5fcd54722e1", + "action": "add" + } + }, + "UserCode": { + "3": { + "version": 3, + "hash": "90fcae0f556f375ba1e91d2e345f57241660695c6e2b84c8e311df89d09e6c66", + "action": "add" + } + }, + "SubmitUserCode": { + "3": { + "version": 3, + "hash": "a29160c16d2e2620800d42cdcd9f3637d063a570c477a5d05217a2e64b4bb396", + "action": "add" + } + }, + "CreateCustomImageChange": { + "1": { + "version": 1, + "hash": "bc09dca7995938f3b3a2bd9c8b3c2feffc8484df466144a425cb69cadb2ab635", + "action": "add" + } + }, + "CreateCustomWorkerPoolChange": { + "1": { + "version": 1, + "hash": "86894f8ccc037de61f44f9698fd113ba02c3cf3870a3048c00a46e15dcd1941c", + "action": "add" + } + }, + "JobInfo": { + "1": { + "version": 1, + "hash": "cf26eeac3d9254dfa439917493b816341f8a379a77d182bbecba3b7ed2c1d00a", + "action": "add" + } + }, + "User": { + "1": { + "version": 1, + "hash": "078636e64f737e60245b39cf348d30fb006531e80c12b70aa7cf98254e1bb37a", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "ded970c92f202716ed33a2117cf541789f35fad66bd4b1db39da5026b1d7d0e7", + "action": "add" + } + }, + "UserUpdate": { + "1": { + "version": 1, + "hash": "839dd90aeb611e1dc471c8fd6daf230e913465c0625c6a297079cb7f0a271195", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "32cba8fbd786c575f92e26c31384d282e68e3ebfe5c4b0a0e793820b1228d246", + "action": "add" + } + }, + "UserCreate": { + "1": { + "version": 1, + "hash": "dab78b63544ae91c09f9843c323cb237c0a6fcfeb71c1acf5f738e2fcf5c277f", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "2540188c5aaea866914dccff459df6e0f4727108a503414bb1567ff6297d4646", + "action": "add" + } + }, + "UserView": { + "1": { + "version": 1, + "hash": "63289383fe7e7584652f242a4362ce6e2f0ade52f6416ab6149b326a506b0675", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "e410de583bb15bc5af57acef7be55ea5fc56b5b0fc169daa3869f4203c4d7473", + "action": "add" + } + }, + "BlobFile": { + "2": { + "version": 2, + "hash": "f2b29d28fe81a04bf5e946c819010283a9f98a97d50519358bead773865a2e09", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "8f1710c754bb3b39f546b97fd69c4826291398b247976bbc41fa873af431bca9", + "action": "add" + } + }, + "SyftObjectRetrieval": { + "1": { + "version": 1, + "hash": "7ccc62d5b434d2d438b3df661b4d753b0c7c8d593d451d8b86d364da83998c89", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "952958e9afae007bef3cb89aa15be95dddc4c310e3a8ce4191576f90ac6fcbc8", + "action": "add" + } + }, + "ActionFileData": { + "1": { + "version": 1, + "hash": "1f32d94b75b0a6b4e86cec93d94aa905738219e3e7e75f51dd335ee832a6ed3e", + "action": "remove" + } + }, + "SeaweedFSBlobDeposit": { + "2": { + "version": 2, + "hash": "07d84a95324d95d9c868cd7d1c33c908f77aa468671d76c144586aab672bcbb5", + "action": "add" + } + }, + "mock_type": { + "1": { + "version": 1, + "hash": "438b791584a3a832288d9b3038a43a753ee9febd065deff45ae2194a33ba4513", + "action": "add" + } + }, + "71beaeec12414862889743793403de2e": { + "1": { + "version": 1, + "hash": "59c901096d894c20f87a499eef6d2b7e579263f63bafcdaa249459e6a672340f", + "action": "add" + } + }, + "MockStoreConfig": { + "1": { + "version": 1, + "hash": "77750b74b93749a150a3632cf76c8052dcaa246e68c54b54fac3e14c7d78b3d3", + "action": "add" + } + }, + "NodeActionData": { + "1": { + "version": 1, + "hash": "0635d37c81929469f799e1f0889a2e498818ddfed82441795dc6649ff51ccd3b", + "action": "add" + } + }, + "NodeActionDataUpdate": { + "1": { + "version": 1, + "hash": "f0305c6ae72a464f7779fa17a79d405ae89685fc4d38330d4dfded7a572a6e3a", + "action": "add" + } + }, + "InMemoryGraphConfig": { + "1": { + "version": 1, + "hash": "efb9a7a9b8468db54f616dec62bcdcaf8e93637bc0490759f27bcb9ae746c075", + "action": "add" + } + }, + "MockWrapper": { + "1": { + "version": 1, + "hash": "67e74d0ff1db448b75b2ed73ee1e162200f972c8d7c29eb66fcd3095671181e8", + "action": "add" + } + }, + "base_stash_mock_object_type": { + "1": { + "version": 1, + "hash": "5a44200e21f7ce85346a92413f01944fdf0a54e60729244d23f356aa56414968", + "action": "add" + } + } + } + } +} diff --git a/packages/syft/src/syft/protocol/8234fa441ef844f7bf2cb1c05cad9a7a.json b/packages/syft/src/syft/protocol/8234fa441ef844f7bf2cb1c05cad9a7a.json new file mode 100644 index 00000000000..a52fbae529c --- /dev/null +++ b/packages/syft/src/syft/protocol/8234fa441ef844f7bf2cb1c05cad9a7a.json @@ -0,0 +1,304 @@ +{ + "1": { + "release_name": "0.8.2.json" + }, + "2": { + "release_name": "0.8.3.json" + }, + "dev": { + "object_versions": { + "SyftWorkerImage": { + "1": { + "version": 1, + "hash": "2a9585b6a286e24f1a9f3f943d0128730cf853edc549184dc1809d19e1eec54b", + "action": "add" + } + }, + "ActionDataLink": { + "1": { + "version": 1, + "hash": "10bf94e99637695f1ba283f0b10e70743a4ebcb9ee75aefb1a05e6d6e1d21a71", + "action": "add" + } + }, + "ObjectNotReady": { + "1": { + "version": 1, + "hash": "88207988639b11eaca686b6e079616d9caecc3dbc2a8112258e0f39ee5c3e113", + "action": "add" + } + }, + "JobItem": { + "3": { + "version": 3, + "hash": "5b93a59e28574691339d22826d5650969336a2e930b93d6b3fe6d5409ca0cfc4", + "action": "add" + } + }, + "SeaweedSecureFilePathLocation": { + "2": { + "version": 2, + "hash": "5fd63fed2a4efba8c2b6c7a7b5e9b5939181781c331230896aa130b6fd558739", + "action": "add" + } + }, + "AzureSecureFilePathLocation": { + "1": { + "version": 1, + "hash": "1bb15f3f9d7082779f1c9f58de94011487924cb8a8c9c2ec18fd7c161c27fd0e", + "action": "add" + } + }, + "RemoteConfig": { + "1": { + "version": 1, + "hash": "ad7bc4780a8ad52e14ce68601852c93d2fe07bda489809cad7cae786d2461754", + "action": "add" + } + }, + "AzureRemoteConfig": { + "1": { + "version": 1, + "hash": "c05c6caa27db4e385c642536d4b0ecabc0c71e91220d2e6ce21a2761ca68a673", + "action": "add" + } + }, + "BlobRetrievalByURL": { + "2": { + "version": 2, + "hash": "8059ee03016c4d74e408dad9529e877f91829672e0cc42d8cfff9c8e14058adc", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "0b664100ea08413ca4ef04665ca910c2cf9535539617ea4ba33687d05cdfe747", + "action": "add" + } + }, + "QueueItem": { + "3": { + "version": 3, + "hash": "3495f406d2c97050ce86be80c230f49b6b846c63b9a9230cbd6631952f2bad0f", + "action": "add" + } + }, + "ActionQueueItem": { + "2": { + "version": 2, + "hash": "6413ed01e949cac169299a43ce40651f9bf8053e408b6942853f8afa8a693b3d", + "action": "add" + } + }, + "ZMQClientConfig": { + "2": { + "version": 2, + "hash": "0f9bc88d56cd6eed6fc75459d1f914aed840c66e1195b9e41cc501b488fef2ed", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "91ce5953cced58e12c576aa5174d5ca0c91981b01cf42edd5283d347baa3390b", + "action": "add" + } + }, + "SyftWorker": { + "1": { + "version": 1, + "hash": "0d5b367162f3ce55ab090cc1b49bd30e50d4eb144e8431eadc679bd0e743aa70", + "action": "add" + } + }, + "WorkerPool": { + "1": { + "version": 1, + "hash": "250699eb4c452fc427995353d5c5ad6245fb3e9fdac8814f8348784816a0733b", + "action": "add" + } + }, + "SyftImageRegistry": { + "1": { + "version": 1, + "hash": "dc83910c91947e3d9eaa3e6f8592237448f0408668c7cca80450b5fcd54722e1", + "action": "add" + } + }, + "UserCode": { + "3": { + "version": 3, + "hash": "90fcae0f556f375ba1e91d2e345f57241660695c6e2b84c8e311df89d09e6c66", + "action": "add" + } + }, + "SubmitUserCode": { + "3": { + "version": 3, + "hash": "a29160c16d2e2620800d42cdcd9f3637d063a570c477a5d05217a2e64b4bb396", + "action": "add" + } + }, + "CreateCustomImageChange": { + "1": { + "version": 1, + "hash": "bc09dca7995938f3b3a2bd9c8b3c2feffc8484df466144a425cb69cadb2ab635", + "action": "add" + } + }, + "CreateCustomWorkerPoolChange": { + "1": { + "version": 1, + "hash": "86894f8ccc037de61f44f9698fd113ba02c3cf3870a3048c00a46e15dcd1941c", + "action": "add" + } + }, + "JobInfo": { + "1": { + "version": 1, + "hash": "cf26eeac3d9254dfa439917493b816341f8a379a77d182bbecba3b7ed2c1d00a", + "action": "add" + } + }, + "User": { + "1": { + "version": 1, + "hash": "078636e64f737e60245b39cf348d30fb006531e80c12b70aa7cf98254e1bb37a", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "ded970c92f202716ed33a2117cf541789f35fad66bd4b1db39da5026b1d7d0e7", + "action": "add" + } + }, + "UserUpdate": { + "1": { + "version": 1, + "hash": "839dd90aeb611e1dc471c8fd6daf230e913465c0625c6a297079cb7f0a271195", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "32cba8fbd786c575f92e26c31384d282e68e3ebfe5c4b0a0e793820b1228d246", + "action": "add" + } + }, + "UserCreate": { + "1": { + "version": 1, + "hash": "dab78b63544ae91c09f9843c323cb237c0a6fcfeb71c1acf5f738e2fcf5c277f", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "2540188c5aaea866914dccff459df6e0f4727108a503414bb1567ff6297d4646", + "action": "add" + } + }, + "UserView": { + "1": { + "version": 1, + "hash": "63289383fe7e7584652f242a4362ce6e2f0ade52f6416ab6149b326a506b0675", + "action": "remove" + }, + "2": { + "version": 2, + "hash": "e410de583bb15bc5af57acef7be55ea5fc56b5b0fc169daa3869f4203c4d7473", + "action": "add" + } + }, + "BlobFile": { + "2": { + "version": 2, + "hash": "f2b29d28fe81a04bf5e946c819010283a9f98a97d50519358bead773865a2e09", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "8f1710c754bb3b39f546b97fd69c4826291398b247976bbc41fa873af431bca9", + "action": "add" + } + }, + "SyftObjectRetrieval": { + "1": { + "version": 1, + "hash": "7ccc62d5b434d2d438b3df661b4d753b0c7c8d593d451d8b86d364da83998c89", + "action": "remove" + }, + "3": { + "version": 3, + "hash": "952958e9afae007bef3cb89aa15be95dddc4c310e3a8ce4191576f90ac6fcbc8", + "action": "add" + } + }, + "ActionFileData": { + "1": { + "version": 1, + "hash": "1f32d94b75b0a6b4e86cec93d94aa905738219e3e7e75f51dd335ee832a6ed3e", + "action": "remove" + } + }, + "SeaweedFSBlobDeposit": { + "2": { + "version": 2, + "hash": "07d84a95324d95d9c868cd7d1c33c908f77aa468671d76c144586aab672bcbb5", + "action": "add" + } + }, + "mock_type": { + "1": { + "version": 1, + "hash": "438b791584a3a832288d9b3038a43a753ee9febd065deff45ae2194a33ba4513", + "action": "add" + } + }, + "0c4ba3ae77614f5597dd7e162eef2f7d": { + "1": { + "version": 1, + "hash": "b8618818d6b35d9734fd91d701d98757d3bfd0b261fb6946e350822f91c00d27", + "action": "add" + } + }, + "MockStoreConfig": { + "1": { + "version": 1, + "hash": "77750b74b93749a150a3632cf76c8052dcaa246e68c54b54fac3e14c7d78b3d3", + "action": "add" + } + }, + "NodeActionData": { + "1": { + "version": 1, + "hash": "0635d37c81929469f799e1f0889a2e498818ddfed82441795dc6649ff51ccd3b", + "action": "add" + } + }, + "NodeActionDataUpdate": { + "1": { + "version": 1, + "hash": "f0305c6ae72a464f7779fa17a79d405ae89685fc4d38330d4dfded7a572a6e3a", + "action": "add" + } + }, + "InMemoryGraphConfig": { + "1": { + "version": 1, + "hash": "efb9a7a9b8468db54f616dec62bcdcaf8e93637bc0490759f27bcb9ae746c075", + "action": "add" + } + }, + "MockWrapper": { + "1": { + "version": 1, + "hash": "67e74d0ff1db448b75b2ed73ee1e162200f972c8d7c29eb66fcd3095671181e8", + "action": "add" + } + }, + "base_stash_mock_object_type": { + "1": { + "version": 1, + "hash": "5a44200e21f7ce85346a92413f01944fdf0a54e60729244d23f356aa56414968", + "action": "add" + } + } + } + } +} diff --git a/packages/syft/src/syft/service/queue/queue.py b/packages/syft/src/syft/service/queue/queue.py index 2ceaefc6fb1..b72fd11b2ac 100644 --- a/packages/syft/src/syft/service/queue/queue.py +++ b/packages/syft/src/syft/service/queue/queue.py @@ -14,7 +14,6 @@ # relative from ...node.credentials import SyftVerifyKey -from ...node.worker import Worker from ...node.worker_settings import WorkerSettings from ...serde.deserialize import _deserialize as deserialize from ...serde.serializable import serializable @@ -41,10 +40,10 @@ class MonitorThread(threading.Thread): def __init__( self, queue_item: QueueItem, - worker: Worker, + worker: Any, # should be of type Worker(Node), but get circular import error credentials: SyftVerifyKey, interval: int = 5, - ): + ) -> None: super().__init__() self.interval = interval self.stop_requested = threading.Event() diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py index 1d7afdf4ce3..ac6af365f96 100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -968,7 +968,10 @@ def purge_all(self) -> Union[SyftError, SyftSuccess]: @serializable() class ZMQQueueConfig(QueueConfig): def __init__( - self, client_type=None, client_config=None, thread_workers: bool = False + self, + client_type: Optional[Union[ZMQClient, Any]] = None, + client_config: Optional[ZMQClientConfig] = None, + thread_workers: bool = False, ): self.client_type = client_type or ZMQClient self.client_config: ZMQClientConfig = client_config or ZMQClientConfig() From 6ce490e5539e0cb99c65363471aa43dc715d3fdd Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 19 Feb 2024 09:56:08 +0700 Subject: [PATCH 11/42] delete unnecessary protocol json files --- .../32b1f4e9972b4f9583f99b67944a1623.json | 304 ------------------ .../8234fa441ef844f7bf2cb1c05cad9a7a.json | 304 ------------------ 2 files changed, 608 deletions(-) delete mode 100644 packages/syft/src/syft/protocol/32b1f4e9972b4f9583f99b67944a1623.json delete mode 100644 packages/syft/src/syft/protocol/8234fa441ef844f7bf2cb1c05cad9a7a.json diff --git a/packages/syft/src/syft/protocol/32b1f4e9972b4f9583f99b67944a1623.json b/packages/syft/src/syft/protocol/32b1f4e9972b4f9583f99b67944a1623.json deleted file mode 100644 index 86f1e20ca5e..00000000000 --- a/packages/syft/src/syft/protocol/32b1f4e9972b4f9583f99b67944a1623.json +++ /dev/null @@ -1,304 +0,0 @@ -{ - "1": { - "release_name": "0.8.2.json" - }, - "2": { - "release_name": "0.8.3.json" - }, - "dev": { - "object_versions": { - "SyftWorkerImage": { - "1": { - "version": 1, - "hash": "2a9585b6a286e24f1a9f3f943d0128730cf853edc549184dc1809d19e1eec54b", - "action": "add" - } - }, - "ActionDataLink": { - "1": { - "version": 1, - "hash": "10bf94e99637695f1ba283f0b10e70743a4ebcb9ee75aefb1a05e6d6e1d21a71", - "action": "add" - } - }, - "ObjectNotReady": { - "1": { - "version": 1, - "hash": "88207988639b11eaca686b6e079616d9caecc3dbc2a8112258e0f39ee5c3e113", - "action": "add" - } - }, - "JobItem": { - "3": { - "version": 3, - "hash": "5b93a59e28574691339d22826d5650969336a2e930b93d6b3fe6d5409ca0cfc4", - "action": "add" - } - }, - "SeaweedSecureFilePathLocation": { - "2": { - "version": 2, - "hash": "5fd63fed2a4efba8c2b6c7a7b5e9b5939181781c331230896aa130b6fd558739", - "action": "add" - } - }, - "AzureSecureFilePathLocation": { - "1": { - "version": 1, - "hash": "1bb15f3f9d7082779f1c9f58de94011487924cb8a8c9c2ec18fd7c161c27fd0e", - "action": "add" - } - }, - "RemoteConfig": { - "1": { - "version": 1, - "hash": "ad7bc4780a8ad52e14ce68601852c93d2fe07bda489809cad7cae786d2461754", - "action": "add" - } - }, - "AzureRemoteConfig": { - "1": { - "version": 1, - "hash": "c05c6caa27db4e385c642536d4b0ecabc0c71e91220d2e6ce21a2761ca68a673", - "action": "add" - } - }, - "BlobRetrievalByURL": { - "2": { - "version": 2, - "hash": "8059ee03016c4d74e408dad9529e877f91829672e0cc42d8cfff9c8e14058adc", - "action": "remove" - }, - "3": { - "version": 3, - "hash": "0b664100ea08413ca4ef04665ca910c2cf9535539617ea4ba33687d05cdfe747", - "action": "add" - } - }, - "QueueItem": { - "3": { - "version": 3, - "hash": "3495f406d2c97050ce86be80c230f49b6b846c63b9a9230cbd6631952f2bad0f", - "action": "add" - } - }, - "ActionQueueItem": { - "2": { - "version": 2, - "hash": "6413ed01e949cac169299a43ce40651f9bf8053e408b6942853f8afa8a693b3d", - "action": "add" - } - }, - "ZMQClientConfig": { - "2": { - "version": 2, - "hash": "0f9bc88d56cd6eed6fc75459d1f914aed840c66e1195b9e41cc501b488fef2ed", - "action": "remove" - }, - "3": { - "version": 3, - "hash": "91ce5953cced58e12c576aa5174d5ca0c91981b01cf42edd5283d347baa3390b", - "action": "add" - } - }, - "SyftWorker": { - "1": { - "version": 1, - "hash": "0d5b367162f3ce55ab090cc1b49bd30e50d4eb144e8431eadc679bd0e743aa70", - "action": "add" - } - }, - "WorkerPool": { - "1": { - "version": 1, - "hash": "250699eb4c452fc427995353d5c5ad6245fb3e9fdac8814f8348784816a0733b", - "action": "add" - } - }, - "SyftImageRegistry": { - "1": { - "version": 1, - "hash": "dc83910c91947e3d9eaa3e6f8592237448f0408668c7cca80450b5fcd54722e1", - "action": "add" - } - }, - "UserCode": { - "3": { - "version": 3, - "hash": "90fcae0f556f375ba1e91d2e345f57241660695c6e2b84c8e311df89d09e6c66", - "action": "add" - } - }, - "SubmitUserCode": { - "3": { - "version": 3, - "hash": "a29160c16d2e2620800d42cdcd9f3637d063a570c477a5d05217a2e64b4bb396", - "action": "add" - } - }, - "CreateCustomImageChange": { - "1": { - "version": 1, - "hash": "bc09dca7995938f3b3a2bd9c8b3c2feffc8484df466144a425cb69cadb2ab635", - "action": "add" - } - }, - "CreateCustomWorkerPoolChange": { - "1": { - "version": 1, - "hash": "86894f8ccc037de61f44f9698fd113ba02c3cf3870a3048c00a46e15dcd1941c", - "action": "add" - } - }, - "JobInfo": { - "1": { - "version": 1, - "hash": "cf26eeac3d9254dfa439917493b816341f8a379a77d182bbecba3b7ed2c1d00a", - "action": "add" - } - }, - "User": { - "1": { - "version": 1, - "hash": "078636e64f737e60245b39cf348d30fb006531e80c12b70aa7cf98254e1bb37a", - "action": "remove" - }, - "2": { - "version": 2, - "hash": "ded970c92f202716ed33a2117cf541789f35fad66bd4b1db39da5026b1d7d0e7", - "action": "add" - } - }, - "UserUpdate": { - "1": { - "version": 1, - "hash": "839dd90aeb611e1dc471c8fd6daf230e913465c0625c6a297079cb7f0a271195", - "action": "remove" - }, - "2": { - "version": 2, - "hash": "32cba8fbd786c575f92e26c31384d282e68e3ebfe5c4b0a0e793820b1228d246", - "action": "add" - } - }, - "UserCreate": { - "1": { - "version": 1, - "hash": "dab78b63544ae91c09f9843c323cb237c0a6fcfeb71c1acf5f738e2fcf5c277f", - "action": "remove" - }, - "2": { - "version": 2, - "hash": "2540188c5aaea866914dccff459df6e0f4727108a503414bb1567ff6297d4646", - "action": "add" - } - }, - "UserView": { - "1": { - "version": 1, - "hash": "63289383fe7e7584652f242a4362ce6e2f0ade52f6416ab6149b326a506b0675", - "action": "remove" - }, - "2": { - "version": 2, - "hash": "e410de583bb15bc5af57acef7be55ea5fc56b5b0fc169daa3869f4203c4d7473", - "action": "add" - } - }, - "BlobFile": { - "2": { - "version": 2, - "hash": "f2b29d28fe81a04bf5e946c819010283a9f98a97d50519358bead773865a2e09", - "action": "remove" - }, - "3": { - "version": 3, - "hash": "8f1710c754bb3b39f546b97fd69c4826291398b247976bbc41fa873af431bca9", - "action": "add" - } - }, - "SyftObjectRetrieval": { - "1": { - "version": 1, - "hash": "7ccc62d5b434d2d438b3df661b4d753b0c7c8d593d451d8b86d364da83998c89", - "action": "remove" - }, - "3": { - "version": 3, - "hash": "952958e9afae007bef3cb89aa15be95dddc4c310e3a8ce4191576f90ac6fcbc8", - "action": "add" - } - }, - "ActionFileData": { - "1": { - "version": 1, - "hash": "1f32d94b75b0a6b4e86cec93d94aa905738219e3e7e75f51dd335ee832a6ed3e", - "action": "remove" - } - }, - "SeaweedFSBlobDeposit": { - "2": { - "version": 2, - "hash": "07d84a95324d95d9c868cd7d1c33c908f77aa468671d76c144586aab672bcbb5", - "action": "add" - } - }, - "mock_type": { - "1": { - "version": 1, - "hash": "438b791584a3a832288d9b3038a43a753ee9febd065deff45ae2194a33ba4513", - "action": "add" - } - }, - "71beaeec12414862889743793403de2e": { - "1": { - "version": 1, - "hash": "59c901096d894c20f87a499eef6d2b7e579263f63bafcdaa249459e6a672340f", - "action": "add" - } - }, - "MockStoreConfig": { - "1": { - "version": 1, - "hash": "77750b74b93749a150a3632cf76c8052dcaa246e68c54b54fac3e14c7d78b3d3", - "action": "add" - } - }, - "NodeActionData": { - "1": { - "version": 1, - "hash": "0635d37c81929469f799e1f0889a2e498818ddfed82441795dc6649ff51ccd3b", - "action": "add" - } - }, - "NodeActionDataUpdate": { - "1": { - "version": 1, - "hash": "f0305c6ae72a464f7779fa17a79d405ae89685fc4d38330d4dfded7a572a6e3a", - "action": "add" - } - }, - "InMemoryGraphConfig": { - "1": { - "version": 1, - "hash": "efb9a7a9b8468db54f616dec62bcdcaf8e93637bc0490759f27bcb9ae746c075", - "action": "add" - } - }, - "MockWrapper": { - "1": { - "version": 1, - "hash": "67e74d0ff1db448b75b2ed73ee1e162200f972c8d7c29eb66fcd3095671181e8", - "action": "add" - } - }, - "base_stash_mock_object_type": { - "1": { - "version": 1, - "hash": "5a44200e21f7ce85346a92413f01944fdf0a54e60729244d23f356aa56414968", - "action": "add" - } - } - } - } -} diff --git a/packages/syft/src/syft/protocol/8234fa441ef844f7bf2cb1c05cad9a7a.json b/packages/syft/src/syft/protocol/8234fa441ef844f7bf2cb1c05cad9a7a.json deleted file mode 100644 index a52fbae529c..00000000000 --- a/packages/syft/src/syft/protocol/8234fa441ef844f7bf2cb1c05cad9a7a.json +++ /dev/null @@ -1,304 +0,0 @@ -{ - "1": { - "release_name": "0.8.2.json" - }, - "2": { - "release_name": "0.8.3.json" - }, - "dev": { - "object_versions": { - "SyftWorkerImage": { - "1": { - "version": 1, - "hash": "2a9585b6a286e24f1a9f3f943d0128730cf853edc549184dc1809d19e1eec54b", - "action": "add" - } - }, - "ActionDataLink": { - "1": { - "version": 1, - "hash": "10bf94e99637695f1ba283f0b10e70743a4ebcb9ee75aefb1a05e6d6e1d21a71", - "action": "add" - } - }, - "ObjectNotReady": { - "1": { - "version": 1, - "hash": "88207988639b11eaca686b6e079616d9caecc3dbc2a8112258e0f39ee5c3e113", - "action": "add" - } - }, - "JobItem": { - "3": { - "version": 3, - "hash": "5b93a59e28574691339d22826d5650969336a2e930b93d6b3fe6d5409ca0cfc4", - "action": "add" - } - }, - "SeaweedSecureFilePathLocation": { - "2": { - "version": 2, - "hash": "5fd63fed2a4efba8c2b6c7a7b5e9b5939181781c331230896aa130b6fd558739", - "action": "add" - } - }, - "AzureSecureFilePathLocation": { - "1": { - "version": 1, - "hash": "1bb15f3f9d7082779f1c9f58de94011487924cb8a8c9c2ec18fd7c161c27fd0e", - "action": "add" - } - }, - "RemoteConfig": { - "1": { - "version": 1, - "hash": "ad7bc4780a8ad52e14ce68601852c93d2fe07bda489809cad7cae786d2461754", - "action": "add" - } - }, - "AzureRemoteConfig": { - "1": { - "version": 1, - "hash": "c05c6caa27db4e385c642536d4b0ecabc0c71e91220d2e6ce21a2761ca68a673", - "action": "add" - } - }, - "BlobRetrievalByURL": { - "2": { - "version": 2, - "hash": "8059ee03016c4d74e408dad9529e877f91829672e0cc42d8cfff9c8e14058adc", - "action": "remove" - }, - "3": { - "version": 3, - "hash": "0b664100ea08413ca4ef04665ca910c2cf9535539617ea4ba33687d05cdfe747", - "action": "add" - } - }, - "QueueItem": { - "3": { - "version": 3, - "hash": "3495f406d2c97050ce86be80c230f49b6b846c63b9a9230cbd6631952f2bad0f", - "action": "add" - } - }, - "ActionQueueItem": { - "2": { - "version": 2, - "hash": "6413ed01e949cac169299a43ce40651f9bf8053e408b6942853f8afa8a693b3d", - "action": "add" - } - }, - "ZMQClientConfig": { - "2": { - "version": 2, - "hash": "0f9bc88d56cd6eed6fc75459d1f914aed840c66e1195b9e41cc501b488fef2ed", - "action": "remove" - }, - "3": { - "version": 3, - "hash": "91ce5953cced58e12c576aa5174d5ca0c91981b01cf42edd5283d347baa3390b", - "action": "add" - } - }, - "SyftWorker": { - "1": { - "version": 1, - "hash": "0d5b367162f3ce55ab090cc1b49bd30e50d4eb144e8431eadc679bd0e743aa70", - "action": "add" - } - }, - "WorkerPool": { - "1": { - "version": 1, - "hash": "250699eb4c452fc427995353d5c5ad6245fb3e9fdac8814f8348784816a0733b", - "action": "add" - } - }, - "SyftImageRegistry": { - "1": { - "version": 1, - "hash": "dc83910c91947e3d9eaa3e6f8592237448f0408668c7cca80450b5fcd54722e1", - "action": "add" - } - }, - "UserCode": { - "3": { - "version": 3, - "hash": "90fcae0f556f375ba1e91d2e345f57241660695c6e2b84c8e311df89d09e6c66", - "action": "add" - } - }, - "SubmitUserCode": { - "3": { - "version": 3, - "hash": "a29160c16d2e2620800d42cdcd9f3637d063a570c477a5d05217a2e64b4bb396", - "action": "add" - } - }, - "CreateCustomImageChange": { - "1": { - "version": 1, - "hash": "bc09dca7995938f3b3a2bd9c8b3c2feffc8484df466144a425cb69cadb2ab635", - "action": "add" - } - }, - "CreateCustomWorkerPoolChange": { - "1": { - "version": 1, - "hash": "86894f8ccc037de61f44f9698fd113ba02c3cf3870a3048c00a46e15dcd1941c", - "action": "add" - } - }, - "JobInfo": { - "1": { - "version": 1, - "hash": "cf26eeac3d9254dfa439917493b816341f8a379a77d182bbecba3b7ed2c1d00a", - "action": "add" - } - }, - "User": { - "1": { - "version": 1, - "hash": "078636e64f737e60245b39cf348d30fb006531e80c12b70aa7cf98254e1bb37a", - "action": "remove" - }, - "2": { - "version": 2, - "hash": "ded970c92f202716ed33a2117cf541789f35fad66bd4b1db39da5026b1d7d0e7", - "action": "add" - } - }, - "UserUpdate": { - "1": { - "version": 1, - "hash": "839dd90aeb611e1dc471c8fd6daf230e913465c0625c6a297079cb7f0a271195", - "action": "remove" - }, - "2": { - "version": 2, - "hash": "32cba8fbd786c575f92e26c31384d282e68e3ebfe5c4b0a0e793820b1228d246", - "action": "add" - } - }, - "UserCreate": { - "1": { - "version": 1, - "hash": "dab78b63544ae91c09f9843c323cb237c0a6fcfeb71c1acf5f738e2fcf5c277f", - "action": "remove" - }, - "2": { - "version": 2, - "hash": "2540188c5aaea866914dccff459df6e0f4727108a503414bb1567ff6297d4646", - "action": "add" - } - }, - "UserView": { - "1": { - "version": 1, - "hash": "63289383fe7e7584652f242a4362ce6e2f0ade52f6416ab6149b326a506b0675", - "action": "remove" - }, - "2": { - "version": 2, - "hash": "e410de583bb15bc5af57acef7be55ea5fc56b5b0fc169daa3869f4203c4d7473", - "action": "add" - } - }, - "BlobFile": { - "2": { - "version": 2, - "hash": "f2b29d28fe81a04bf5e946c819010283a9f98a97d50519358bead773865a2e09", - "action": "remove" - }, - "3": { - "version": 3, - "hash": "8f1710c754bb3b39f546b97fd69c4826291398b247976bbc41fa873af431bca9", - "action": "add" - } - }, - "SyftObjectRetrieval": { - "1": { - "version": 1, - "hash": "7ccc62d5b434d2d438b3df661b4d753b0c7c8d593d451d8b86d364da83998c89", - "action": "remove" - }, - "3": { - "version": 3, - "hash": "952958e9afae007bef3cb89aa15be95dddc4c310e3a8ce4191576f90ac6fcbc8", - "action": "add" - } - }, - "ActionFileData": { - "1": { - "version": 1, - "hash": "1f32d94b75b0a6b4e86cec93d94aa905738219e3e7e75f51dd335ee832a6ed3e", - "action": "remove" - } - }, - "SeaweedFSBlobDeposit": { - "2": { - "version": 2, - "hash": "07d84a95324d95d9c868cd7d1c33c908f77aa468671d76c144586aab672bcbb5", - "action": "add" - } - }, - "mock_type": { - "1": { - "version": 1, - "hash": "438b791584a3a832288d9b3038a43a753ee9febd065deff45ae2194a33ba4513", - "action": "add" - } - }, - "0c4ba3ae77614f5597dd7e162eef2f7d": { - "1": { - "version": 1, - "hash": "b8618818d6b35d9734fd91d701d98757d3bfd0b261fb6946e350822f91c00d27", - "action": "add" - } - }, - "MockStoreConfig": { - "1": { - "version": 1, - "hash": "77750b74b93749a150a3632cf76c8052dcaa246e68c54b54fac3e14c7d78b3d3", - "action": "add" - } - }, - "NodeActionData": { - "1": { - "version": 1, - "hash": "0635d37c81929469f799e1f0889a2e498818ddfed82441795dc6649ff51ccd3b", - "action": "add" - } - }, - "NodeActionDataUpdate": { - "1": { - "version": 1, - "hash": "f0305c6ae72a464f7779fa17a79d405ae89685fc4d38330d4dfded7a572a6e3a", - "action": "add" - } - }, - "InMemoryGraphConfig": { - "1": { - "version": 1, - "hash": "efb9a7a9b8468db54f616dec62bcdcaf8e93637bc0490759f27bcb9ae746c075", - "action": "add" - } - }, - "MockWrapper": { - "1": { - "version": 1, - "hash": "67e74d0ff1db448b75b2ed73ee1e162200f972c8d7c29eb66fcd3095671181e8", - "action": "add" - } - }, - "base_stash_mock_object_type": { - "1": { - "version": 1, - "hash": "5a44200e21f7ce85346a92413f01944fdf0a54e60729244d23f356aa56414968", - "action": "add" - } - } - } - } -} From 177807002c6849cb8830ffeef7ac9f873651a955 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 19 Feb 2024 15:57:55 +0700 Subject: [PATCH 12/42] [refactor] done fixing mypy issues for `service/queue` --- .pre-commit-config.yaml | 2 +- .../syft/src/syft/service/queue/base_queue.py | 6 +- .../syft/src/syft/service/queue/zmq_queue.py | 158 ++++++++++-------- 3 files changed, 90 insertions(+), 76 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cd54f42b1f6..b71cc95d5fd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -173,7 +173,7 @@ repos: name: "mypy: syft" always_run: true # files: "^packages/syft/src/syft/serde|^packages/syft/src/syft/util|^packages/syft/src/syft/service" - files: "^packages/syft/src/syft/service/queue/queue.py|^packages/syft/src/syft/service/queue/base_queue.py" + files: "^packages/syft/src/syft/service/queue" args: [ "--follow-imports=skip", "--ignore-missing-imports", diff --git a/packages/syft/src/syft/service/queue/base_queue.py b/packages/syft/src/syft/service/queue/base_queue.py index 6e2d811fe22..1fe914bf8a6 100644 --- a/packages/syft/src/syft/service/queue/base_queue.py +++ b/packages/syft/src/syft/service/queue/base_queue.py @@ -47,11 +47,15 @@ def close(self) -> None: @serializable() class QueueProducer: - address: str queue_name: str + @property + def address(self) -> str: + raise NotImplementedError + def send( self, + worker: bytes, message: Any, ) -> None: raise NotImplementedError diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py index ac6af365f96..9fc393c1b11 100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -7,10 +7,12 @@ import time from time import sleep from typing import Any +from typing import Callable from typing import DefaultDict from typing import Dict from typing import List from typing import Optional +from typing import Type from typing import Union # third party @@ -48,6 +50,7 @@ from .base_queue import QueueConsumer from .base_queue import QueueProducer from .queue_stash import ActionQueueItem +from .queue_stash import QueueStash from .queue_stash import Status # Producer/Consumer heartbeat interval (in seconds) @@ -84,12 +87,12 @@ class QueueMsgProtocol: class Timeout: def __init__(self, offset_sec: float): self.__offset = float(offset_sec) - self.__next_ts = 0 + self.__next_ts: float = 0.0 self.reset() @property - def next_ts(self) -> Union[float, int]: + def next_ts(self) -> float: return self.__next_ts def reset(self) -> None: @@ -103,10 +106,17 @@ def now() -> float: return time.time() +class Service: + def __init__(self, name: str) -> None: + self.name = name + self.requests: list[bytes] = [] + self.waiting: list[Worker] = [] # List of waiting workers + + class Worker(SyftBaseModel): address: bytes identity: bytes - service: Optional[str] = None + service: Optional[Service] = None syft_worker_id: Optional[UID] = None expiry_t: Timeout = Timeout(WORKER_TIMEOUT_SEC) @@ -116,23 +126,16 @@ def set_syft_worker_id(cls, v: Any, values: Any) -> Union[UID, Any]: return UID(v) return v - def has_expired(self): + def has_expired(self) -> bool: return self.expiry_t.has_expired() - def get_expiry(self) -> int: + def get_expiry(self) -> float: return self.expiry_t.next_ts - def reset_expiry(self): + def reset_expiry(self) -> None: self.expiry_t.reset() -class Service: - def __init__(self, name: str) -> None: - self.name = name - self.requests = [] - self.waiting = [] # List of waiting workers - - @serializable() class ZMQProducer(QueueProducer): INTERNAL_SERVICE_PREFIX = b"mmi." @@ -140,7 +143,7 @@ class ZMQProducer(QueueProducer): def __init__( self, queue_name: str, - queue_stash, + queue_stash: QueueStash, worker_stash: WorkerStash, port: int, context: AuthedServiceContext, @@ -155,14 +158,14 @@ def __init__( self.post_init() @property - def address(self): + def address(self) -> str: return get_queue_address(self.port) - def post_init(self): + def post_init(self) -> None: """Initialize producer state.""" - self.services = {} - self.workers = {} + self.services: dict[str, Service] = {} + self.workers: dict[bytes, Worker] = {} self.waiting: List[Worker] = [] self.heartbeat_t = Timeout(HEARTBEAT_INTERVAL_SEC) self.context = zmq.Context(1) @@ -172,10 +175,10 @@ def post_init(self): self.poll_workers = zmq.Poller() self.poll_workers.register(self.socket, zmq.POLLIN) self.bind(f"tcp://*:{self.port}") - self.thread: threading.Thread = None - self.producer_thread: threading.Thread = None + self.thread: Optional[threading.Thread] = None + self.producer_thread: Optional[threading.Thread] = None - def close(self): + def close(self) -> None: self._stop.set() try: @@ -197,10 +200,10 @@ def close(self): self._stop.clear() @property - def action_service(self): + def action_service(self) -> Callable: return self.auth_context.node.get_service("ActionService") - def contains_unresolved_action_objects(self, arg, recursion=0): + def contains_unresolved_action_objects(self, arg: Any, recursion: int = 0) -> bool: """recursively check collections for unresolved action objects""" if isinstance(arg, UID): arg = self.action_service.get(self.auth_context, arg).ok() @@ -236,7 +239,7 @@ def contains_unresolved_action_objects(self, arg, recursion=0): logger.exception("Failed to resolve action objects. {}", e) return True - def unwrap_nested_actionobjects(self, data): + def unwrap_nested_actionobjects(self, data: Any) -> Any: """recursively unwraps nested action objects""" if isinstance(data, List): @@ -258,7 +261,7 @@ def unwrap_nested_actionobjects(self, data): return nested_res return data - def preprocess_action_arg(self, arg): + def preprocess_action_arg(self, arg: Any) -> None: res = self.action_service.get(context=self.auth_context, uid=arg) if res.is_err(): return arg @@ -270,7 +273,7 @@ def preprocess_action_arg(self, arg): context=self.auth_context, action_object=new_action_object ) - def read_items(self): + def read_items(self) -> None: while True: if self._stop.is_set(): break @@ -311,7 +314,7 @@ def read_items(self): ) worker_pool = worker_pool.ok() service_name = worker_pool.name - service: Service = self.services.get(service_name) + service: Optional[Service] = self.services.get(service_name) # Skip adding message if corresponding service/pool # is not registered. @@ -339,30 +342,30 @@ def read_items(self): # else decrease retry count and mark status as CREATED. pass - def run(self): + def run(self) -> None: self.thread = threading.Thread(target=self._run) self.thread.start() self.producer_thread = threading.Thread(target=self.read_items) self.producer_thread.start() - def send(self, worker: bytes, message: Union[bytes, List[bytes]]): + def send(self, worker: bytes, message: Union[bytes, List[bytes]]) -> None: worker_obj = self.require_worker(worker) self.send_to_worker(worker=worker_obj, msg=message) - def bind(self, endpoint): + def bind(self, endpoint: str) -> None: """Bind producer to endpoint.""" self.socket.bind(endpoint) logger.info("Producer endpoint: {}", endpoint) - def send_heartbeats(self): + def send_heartbeats(self) -> None: """Send heartbeats to idle workers if it's time""" if self.heartbeat_t.has_expired(): for worker in self.waiting: self.send_to_worker(worker, QueueMsgProtocol.W_HEARTBEAT, None, None) self.heartbeat_t.reset() - def purge_workers(self): + def purge_workers(self) -> None: """Look for & kill expired workers. Workers are oldest to most recent, so we stop at the first alive worker. @@ -381,7 +384,7 @@ def purge_workers(self): def update_consumer_state_for_worker( self, syft_worker_id: UID, consumer_state: ConsumerState - ): + ) -> None: if self.worker_stash is None: logger.error( f"Worker stash is not defined for ZMQProducer : {self.queue_name} - {self.id}" @@ -405,18 +408,18 @@ def update_consumer_state_for_worker( f"Failed to update consumer state for worker id: {syft_worker_id}. Error: {e}" ) - def worker_waiting(self, worker: Worker): + def worker_waiting(self, worker: Worker) -> None: """This worker is now waiting for work.""" # Queue to broker and service waiting lists if worker not in self.waiting: self.waiting.append(worker) - if worker not in worker.service.waiting: + if worker.service is not None and worker not in worker.service.waiting: worker.service.waiting.append(worker) worker.reset_expiry() self.update_consumer_state_for_worker(worker.syft_worker_id, ConsumerState.IDLE) self.dispatch(worker.service, None) - def dispatch(self, service: Service, msg: bytes): + def dispatch(self, service: Service, msg: bytes) -> None: """Dispatch requests to waiting workers as possible""" if msg is not None: # Queue message if any service.requests.append(msg) @@ -432,10 +435,10 @@ def dispatch(self, service: Service, msg: bytes): def send_to_worker( self, worker: Worker, - command: QueueMsgProtocol = QueueMsgProtocol.W_REQUEST, - option: bytes = None, + command: bytes = QueueMsgProtocol.W_REQUEST, + option: Optional[bytes] = None, msg: Optional[Union[bytes, list]] = None, - ): + ) -> None: """Send message to worker. If message is provided, sends that message. @@ -456,7 +459,7 @@ def send_to_worker( with ZMQ_SOCKET_LOCK: self.socket.send_multipart(msg) - def _run(self): + def _run(self) -> None: while True: if self._stop.is_set(): return @@ -488,7 +491,7 @@ def _run(self): self.send_heartbeats() self.purge_workers() - def require_worker(self, address): + def require_worker(self, address: bytes) -> Worker: """Finds the worker (creates if necessary).""" identity = hexlify(address) worker = self.workers.get(identity) @@ -497,7 +500,7 @@ def require_worker(self, address): self.workers[identity] = worker return worker - def process_worker(self, address: bytes, msg: List[bytes]): + def process_worker(self, address: bytes, msg: List[bytes]) -> None: command = msg.pop(0) worker_ready = hexlify(address) in self.workers @@ -515,19 +518,26 @@ def process_worker(self, address: bytes, msg: List[bytes]): self.delete_worker(worker, True) else: # Attach worker to service and mark as idle - if service_name not in self.services: + if service_name in self.services: + service: Optional[Service] = self.services.get(service_name) + else: service = Service(service_name) self.services[service_name] = service + if service is not None: + worker.service = service + logger.info( + "New Worker service={}, id={}, uid={}", + service.name, + worker.identity, + worker.syft_worker_id, + ) else: - service = self.services.get(service_name) - worker.service = service + logger.info( + "New Worker service=None, id={}, uid={}", + worker.identity, + worker.syft_worker_id, + ) worker.syft_worker_id = UID(syft_worker_id) - logger.info( - "New Worker service={} id={} uid={}", - service.name, - worker.identity, - worker.syft_worker_id, - ) self.worker_waiting(worker) elif QueueMsgProtocol.W_HEARTBEAT == command: @@ -546,7 +556,7 @@ def process_worker(self, address: bytes, msg: List[bytes]): else: logger.error("Invalid command: {}", command) - def delete_worker(self, worker: Worker, disconnect: bool): + def delete_worker(self, worker: Worker, disconnect: bool) -> None: """Deletes worker from all data structures, and deletes worker.""" if disconnect: self.send_to_worker(worker, QueueMsgProtocol.W_DISCONNECT, None, None) @@ -564,7 +574,7 @@ def delete_worker(self, worker: Worker, disconnect: bool): ) @property - def alive(self): + def alive(self) -> bool: return not self.socket.closed @@ -594,10 +604,10 @@ def __init__( self.worker_stash = worker_stash self.post_init() - def reconnect_to_producer(self): + def reconnect_to_producer(self) -> None: """Connect or reconnect to producer""" if self.socket: - self.poller.unregister(self.socket) + self.poller.unregister(self.socket) # type: ignore[unreachable] self.socket.close() self.socket = self.context.socket(zmq.DEALER) self.socket.linger = 0 @@ -614,13 +624,13 @@ def reconnect_to_producer(self): [str(self.syft_worker_id).encode()], ) - def post_init(self): - self.thread = None + def post_init(self) -> None: + self.thread: Optional[threading.Thread] = None self.heartbeat_t = Timeout(HEARTBEAT_INTERVAL_SEC) self.producer_ping_t = Timeout(PRODUCER_TIMEOUT_SEC) self.reconnect_to_producer() - def close(self): + def close(self) -> None: self._stop.set() try: self.poller.unregister(self.socket) @@ -639,7 +649,7 @@ def send_to_producer( command: str, option: Optional[bytes] = None, msg: Optional[Union[bytes, list]] = None, - ): + ) -> None: """Send message to producer. If no msg is provided, creates one internally @@ -657,7 +667,7 @@ def send_to_producer( with ZMQ_SOCKET_LOCK: self.socket.send_multipart(msg) - def _run(self): + def _run(self) -> None: """Send reply, if any, to producer and wait for next request.""" try: while True: @@ -728,33 +738,33 @@ def _run(self): logger.info("Worker finished") - def set_producer_alive(self): + def set_producer_alive(self) -> None: self.producer_ping_t.reset() def is_producer_alive(self) -> bool: # producer timer is within timeout return not self.producer_ping_t.has_expired() - def send_heartbeat(self): + def send_heartbeat(self) -> None: if self.heartbeat_t.has_expired() and self.is_producer_alive(): self.send_to_producer(QueueMsgProtocol.W_HEARTBEAT) self.heartbeat_t.reset() - def run(self): + def run(self) -> None: self.thread = threading.Thread(target=self._run) self.thread.start() - def associate_job(self, message: Frame): + def associate_job(self, message: Frame) -> None: try: queue_item = _deserialize(message, from_bytes=True) self._set_worker_job(queue_item.job_id) except Exception as e: logger.exception("Could not associate job. {}", e) - def clear_job(self): + def clear_job(self) -> None: self._set_worker_job(None) - def _set_worker_job(self, job_id: Optional[UID]): + def _set_worker_job(self, job_id: Optional[UID]) -> None: if self.worker_stash is not None: consumer_state = ( ConsumerState.IDLE if job_id is None else ConsumerState.CONSUMING @@ -770,7 +780,7 @@ def _set_worker_job(self, job_id: Optional[UID]): ) @property - def alive(self): + def alive(self) -> bool: return not self.socket.closed and self.is_producer_alive() @@ -812,14 +822,14 @@ class ZMQClientConfig(SyftObject, QueueClientConfig): @migrate(ZMQClientConfig, ZMQClientConfigV1) -def downgrade_zmqclientconfig_v2_to_v1(): +def downgrade_zmqclientconfig_v2_to_v1() -> list[Callable]: return [ drop(["queue_port", "create_producer", "n_consumers"]), ] @migrate(ZMQClientConfigV1, ZMQClientConfig) -def upgrade_zmqclientconfig_v1_to_v2(): +def upgrade_zmqclientconfig_v1_to_v2() -> list[Callable]: return [ make_set_default("queue_port", None), make_set_default("create_producer", False), @@ -841,7 +851,7 @@ def __init__(self, config: ZMQClientConfig) -> None: self.config = config @staticmethod - def _get_free_tcp_port(host: str): + def _get_free_tcp_port(host: str) -> int: with socketserver.TCPServer((host, 0), None) as s: free_port = s.server_address[1] return free_port @@ -850,9 +860,9 @@ def add_producer( self, queue_name: str, port: Optional[int] = None, - queue_stash=None, + queue_stash: Optional[QueueStash] = None, worker_stash: Optional[WorkerStash] = None, - context=None, + context: AuthedServiceContext = None, ) -> ZMQProducer: """Add a producer of a queue. @@ -954,7 +964,7 @@ def purge_queue(self, queue_name: str) -> Union[SyftError, SyftSuccess]: producer.close() # add a new connection - self.add_producer(queue_name=queue_name, address=producer.address) + self.add_producer(queue_name=queue_name, address=producer.address) # type: ignore return SyftSuccess(message=f"Queue: {queue_name} successfully purged") @@ -969,7 +979,7 @@ def purge_all(self) -> Union[SyftError, SyftSuccess]: class ZMQQueueConfig(QueueConfig): def __init__( self, - client_type: Optional[Union[ZMQClient, Any]] = None, + client_type: Optional[Type[ZMQClient]] = None, client_config: Optional[ZMQClientConfig] = None, thread_workers: bool = False, ): From 5f23d1ad72fc367d420b3cb00bdd8f71ac3db8ff Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 19 Feb 2024 20:57:03 +0700 Subject: [PATCH 13/42] [refactor] continue fixing mypy issues in --- .pre-commit-config.yaml | 4 +- .../service/action/action_graph_service.py | 2 +- .../syft/service/dataset/dataset_service.py | 23 +++-- .../syft/src/syft/service/job/job_service.py | 9 +- .../syft/src/syft/service/job/job_stash.py | 95 +++++++++++------- .../syft/src/syft/service/project/project.py | 99 +++++++++++-------- .../syft/src/syft/service/queue/zmq_queue.py | 10 +- .../service/worker/worker_pool_service.py | 51 ++++++---- 8 files changed, 182 insertions(+), 111 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b71cc95d5fd..78d7205afb6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -172,8 +172,8 @@ repos: - id: mypy name: "mypy: syft" always_run: true - # files: "^packages/syft/src/syft/serde|^packages/syft/src/syft/util|^packages/syft/src/syft/service" - files: "^packages/syft/src/syft/service/queue" + files: "^packages/syft/src/syft/serde|^packages/syft/src/syft/util|^packages/syft/src/syft/service" + # files: "^packages/syft/src/syft/" args: [ "--follow-imports=skip", "--ignore-missing-imports", diff --git a/packages/syft/src/syft/service/action/action_graph_service.py b/packages/syft/src/syft/service/action/action_graph_service.py index df09578dcd5..6e06a9c84a0 100644 --- a/packages/syft/src/syft/service/action/action_graph_service.py +++ b/packages/syft/src/syft/service/action/action_graph_service.py @@ -39,7 +39,7 @@ def __init__(self, store: ActionGraphStore): @service_method(path="graph.add_action", name="add_action") def add_action( self, context: AuthedServiceContext, action: Action - ) -> Union[NodeActionData, SyftError]: + ) -> Union[tuple[NodeActionData, NodeActionData], SyftError]: # Create a node for the action input_uids, output_uid = self._extract_input_and_output_from_action( action=action diff --git a/packages/syft/src/syft/service/dataset/dataset_service.py b/packages/syft/src/syft/service/dataset/dataset_service.py index a06a3c72f74..2971e252398 100644 --- a/packages/syft/src/syft/service/dataset/dataset_service.py +++ b/packages/syft/src/syft/service/dataset/dataset_service.py @@ -99,10 +99,16 @@ def add( ) if result.is_err(): return SyftError(message=str(result.err())) - return SyftSuccess( - message=f"Dataset uploaded to '{context.node.name}'. " - f"To see the datasets uploaded by a client on this node, use command `[your_client].datasets`" - ) + if context.node is not None: + return SyftSuccess( + message=f"Dataset uploaded to '{context.node.name}'. " + f"To see the datasets uploaded by a client on this node, use command `[your_client].datasets`" + ) + else: + return SyftSuccess( + message="Dataset uploaded not to a node." + "To see the datasets uploaded by a client on this node, use command `[your_client].datasets`" + ) @service_method( path="dataset.get_all", @@ -124,7 +130,8 @@ def get_all( datasets = result.ok() for dataset in datasets: - dataset.node_uid = context.node.id + if context.node is not None: + dataset.node_uid = context.node.id return _paginate_dataset_collection( datasets=datasets, page_size=page_size, page_index=page_index @@ -162,7 +169,8 @@ def get_by_id( result = self.stash.get_by_uid(context.credentials, uid=uid) if result.is_ok(): dataset = result.ok() - dataset.node_uid = context.node.id + if context.node is not None: + dataset.node_uid = context.node.id return dataset return SyftError(message=result.err()) @@ -175,7 +183,8 @@ def get_by_action_id( if result.is_ok(): datasets = result.ok() for dataset in datasets: - dataset.node_uid = context.node.id + if context.node is not None: + dataset.node_uid = context.node.id return datasets return SyftError(message=result.err()) diff --git a/packages/syft/src/syft/service/job/job_service.py b/packages/syft/src/syft/service/job/job_service.py index aa8267d736c..e3d39e4179f 100644 --- a/packages/syft/src/syft/service/job/job_service.py +++ b/packages/syft/src/syft/service/job/job_service.py @@ -86,6 +86,8 @@ def restart( res = self.stash.get_by_uid(context.credentials, uid=uid) if res.is_err(): return SyftError(message=res.err()) + if context.node is None: + return SyftError(message=f"context {context}'s node is None") job = res.ok() job.status = JobStatus.CREATED @@ -107,9 +109,9 @@ def restart( context.node.queue_stash.set_placeholder(context.credentials, queue_item) context.node.job_stash.set(context.credentials, job) + log_service = context.node.get_service("logservice") result = log_service.restart(context, job.log_id) - if result.is_err(): return SyftError(message=str(result.err())) @@ -185,6 +187,8 @@ def get_active(self, context: AuthedServiceContext) -> Union[List[Job], SyftErro def create_job_for_user_code_id( self, context: AuthedServiceContext, user_code_id: UID ) -> Union[Job, SyftError]: + if context.node is None: + return SyftError(message=f"context {context}'s node is None") job = Job( id=UID(), node_uid=context.node.id, @@ -195,7 +199,6 @@ def create_job_for_user_code_id( job_pid=None, user_code_id=user_code_id, ) - user_code_service = context.node.get_service("usercodeservice") user_code = user_code_service.get_by_uid(context=context, uid=user_code_id) if isinstance(user_code, SyftError): @@ -207,6 +210,8 @@ def create_job_for_user_code_id( ) self.stash.set(context.credentials, job, add_permissions=[permission]) + if context.node is None: + return SyftError(message=f"context {context}'s node is None") log_service = context.node.get_service("logservice") res = log_service.add(context, job.log_id) if isinstance(res, SyftError): diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index 49e2ff8cf25..c69e408a625 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -3,6 +3,7 @@ from datetime import timedelta from enum import Enum from typing import Any +from typing import Callable from typing import Dict from typing import List from typing import Optional @@ -13,12 +14,15 @@ from result import Err from result import Ok from result import Result +from typing_extensions import Self # relative from ...client.api import APIRegistry from ...client.api import SyftAPICall from ...node.credentials import SyftVerifyKey from ...serde.serializable import serializable +from ...service.queue.queue_stash import QueueItem +from ...service.worker.worker_pool import SyftWorker from ...store.document_store import BaseStash from ...store.document_store import DocumentStore from ...store.document_store import PartitionKey @@ -46,6 +50,7 @@ from ..response import SyftError from ..response import SyftNotReady from ..response import SyftSuccess +from ..user.user import UserView @serializable() @@ -140,7 +145,7 @@ def check_user_code_id(cls, values: dict) -> dict: return values @property - def action_display_name(self): + def action_display_name(self) -> str: if self.action is None: return "action" else: @@ -150,18 +155,24 @@ def action_display_name(self): return self.action.job_display_name @property - def time_remaining_string(self): + def time_remaining_string(self) -> Optional[str]: # update state self.fetch() - percentage = round((self.current_iter / self.n_iters) * 100) - blocks_filled = round(percentage / 20) - blocks_empty = 5 - blocks_filled - blocks_filled_str = "█" * blocks_filled - blocks_empty_str = "  " * blocks_empty - return f"{percentage}% |{blocks_filled_str}{blocks_empty_str}|\n{self.current_iter}/{self.n_iters}\n" + if ( + self.current_iter is not None + and self.n_iters is not None + and self.n_iters != 0 + ): + percentage = round((self.current_iter / self.n_iters) * 100) + blocks_filled = round(percentage / 20) + blocks_empty = 5 - blocks_filled + blocks_filled_str = "█" * blocks_filled + blocks_empty_str = "  " * blocks_empty + return f"{percentage}% |{blocks_filled_str}{blocks_empty_str}|\n{self.current_iter}/{self.n_iters}\n" + return None @property - def worker(self): + def worker(self) -> Union[SyftWorker, SyftError]: api = APIRegistry.api_for( node_uid=self.syft_node_location, user_verify_key=self.syft_client_verify_key, @@ -169,7 +180,7 @@ def worker(self): return api.services.worker.get(self.job_worker_id) @property - def eta_string(self): + def eta_string(self) -> Optional[str]: if ( self.current_iter is None or self.current_iter == 0 @@ -178,7 +189,7 @@ def eta_string(self): ): return None - def format_timedelta(local_timedelta): + def format_timedelta(local_timedelta: timedelta) -> str: total_seconds = int(local_timedelta.total_seconds()) hours, leftover = divmod(total_seconds, 3600) minutes, seconds = divmod(leftover, 60) @@ -209,7 +220,7 @@ def format_timedelta(local_timedelta): return f"[{time_passed_str}<{time_remaining_str}]\n{iter_duration_str}" @property - def progress(self) -> str: + def progress(self) -> Optional[str]: if self.status in [JobStatus.PROCESSING, JobStatus.COMPLETED]: if self.current_iter is None: return "" @@ -241,7 +252,7 @@ def apply_info(self, info: "JobInfo") -> None: if info.includes_result: self.result = info.result - def restart(self, kill=False) -> None: + def restart(self, kill: bool = False) -> None: if kill: self.kill() self.fetch() @@ -269,14 +280,14 @@ def restart(self, kill=False) -> None: print( "Job is running or scheduled, if you want to kill it use job.kill() first" ) + return None - def kill(self) -> Union[None, SyftError]: + def kill(self) -> Optional[SyftError]: if self.job_pid is not None: api = APIRegistry.api_for( node_uid=self.syft_node_location, user_verify_key=self.syft_client_verify_key, ) - call = SyftAPICall( node_uid=self.node_uid, path="job.kill", @@ -285,6 +296,7 @@ def kill(self) -> Union[None, SyftError]: blocking=True, ) api.make_call(call) + return None else: return SyftError( message="Job is not running or isn't running in multiprocessing mode." @@ -312,7 +324,7 @@ def fetch(self) -> None: self.current_iter = job.current_iter @property - def subjobs(self): + def subjobs(self) -> Union[list[QueueItem], SyftError]: api = APIRegistry.api_for( node_uid=self.syft_node_location, user_verify_key=self.syft_client_verify_key, @@ -320,14 +332,16 @@ def subjobs(self): return api.services.job.get_subjobs(self.id) @property - def owner(self): + def owner(self) -> Union[UserView, SyftError]: api = APIRegistry.api_for( node_uid=self.syft_node_location, user_verify_key=self.syft_client_verify_key, ) return api.services.user.get_current_user(self.id) - def logs(self, stdout=True, stderr=True, _print=True): + def logs( + self, stdout: bool = True, stderr: bool = True, _print: bool = True + ) -> Optional[str]: api = APIRegistry.api_for( node_uid=self.syft_node_location, user_verify_key=self.syft_client_verify_key, @@ -355,19 +369,20 @@ def logs(self, stdout=True, stderr=True, _print=True): return results_str else: print(results_str) + return None # def __repr__(self) -> str: # return f": {self.status}" def _coll_repr_(self) -> Dict[str, Any]: logs = self.logs(_print=False, stderr=False) - log_lines = logs.split("\n") + if logs is not None: + log_lines = logs.split("\n") subjobs = self.subjobs if len(log_lines) > 2: logs = f"... ({len(log_lines)} lines)\n" + "\n".join(log_lines[-2:]) - else: - logs = logs + created_time = self.creation_time[:-7] if self.creation_time is not None else "" return { "status": f"{self.action_display_name}: {self.status}" + ( @@ -377,7 +392,7 @@ def _coll_repr_(self) -> Dict[str, Any]: ), "progress": self.progress, "eta": self.eta_string, - "created": f"{self.creation_time[:-7]} by {self.owner.email}", + "created": f"{created_time} by {self.owner.email}", "logs": logs, # "result": result, # "parent_id": str(self.parent_job_id) if self.parent_job_id else "-", @@ -385,15 +400,16 @@ def _coll_repr_(self) -> Dict[str, Any]: } @property - def has_parent(self): + def has_parent(self) -> bool: return self.parent_job_id is not None def _repr_markdown_(self) -> str: _ = self.resolve logs = self.logs(_print=False) - logs_w_linenr = "\n".join( - [f"{i} {line}" for i, line in enumerate(logs.rstrip().split("\n"))] - ) + if logs is not None: + logs_w_linenr = "\n".join( + [f"{i} {line}" for i, line in enumerate(logs.rstrip().split("\n"))] + ) if self.status == JobStatus.COMPLETED: logs_w_linenr += "\nJOB COMPLETED" @@ -409,7 +425,7 @@ def _repr_markdown_(self) -> str: """ return as_markdown_code(md) - def wait(self, job_only=False): + def wait(self, job_only: bool = False) -> Union[Any, SyftNotReady]: # stdlib from time import sleep @@ -422,13 +438,13 @@ def wait(self, job_only=False): if self.resolved: return self.resolve - if not job_only: + if not job_only and self.result is not None: self.result.wait() print_warning = True while True: self.fetch() - if print_warning: + if print_warning and self.result is not None: result_obj = api.services.action.get( self.result.id, resolve_nested=False ) @@ -440,9 +456,10 @@ def wait(self, job_only=False): ) print_warning = False sleep(2) + # TODO: fix the mypy issue if self.resolved: - break - return self.resolve + break # type: ignore[unreachable] + return self.resolve # type: ignore[unreachable] @property def resolve(self) -> Union[Any, SyftNotReady]: @@ -518,7 +535,7 @@ def from_job( job: Job, metadata: bool = False, result: bool = False, - ): + ) -> Self: info = cls( includes_metadata=metadata, includes_result=result, @@ -539,12 +556,12 @@ def from_job( @migrate(Job, JobV2) -def downgrade_job_v3_to_v2(): +def downgrade_job_v3_to_v2() -> list[Callable]: return [drop(["job_worker_id", "user_code_id"])] @migrate(JobV2, Job) -def upgrade_job_v2_to_v3(): +def upgrade_job_v2_to_v3() -> list[Callable]: return [ make_set_default("job_worker_id", None), make_set_default("user_code_id", None), @@ -552,14 +569,14 @@ def upgrade_job_v2_to_v3(): @migrate(JobV2, JobV1) -def downgrade_job_v2_to_v1(): +def downgrade_job_v2_to_v1() -> list[Callable]: return [ drop("job_pid"), ] @migrate(JobV1, JobV2) -def upgrade_job_v1_to_v2(): +def upgrade_job_v1_to_v2() -> list[Callable]: return [make_set_default("job_pid", None)] @@ -636,7 +653,9 @@ def get_active(self, credentials: SyftVerifyKey) -> Result[SyftSuccess, str]: ) return self.query_all(credentials=credentials, qks=qks) - def get_by_worker(self, credentials: SyftVerifyKey, worker_id: str): + def get_by_worker( + self, credentials: SyftVerifyKey, worker_id: str + ) -> Result[List[Job], str]: qks = QueryKeys( qks=[PartitionKey(key="job_worker_id", type_=str).with_obj(worker_id)] ) @@ -644,7 +663,7 @@ def get_by_worker(self, credentials: SyftVerifyKey, worker_id: str): def get_by_user_code_id( self, credentials: SyftVerifyKey, user_code_id: UID - ) -> Union[List[Job], SyftError]: + ) -> Result[List[Job], str]: qks = QueryKeys( qks=[PartitionKey(key="user_code_id", type_=UID).with_obj(user_code_id)] ) diff --git a/packages/syft/src/syft/service/project/project.py b/packages/syft/src/syft/service/project/project.py index 57fc514c5d3..4f55bc444f5 100644 --- a/packages/syft/src/syft/service/project/project.py +++ b/packages/syft/src/syft/service/project/project.py @@ -55,6 +55,7 @@ from ..request.request import RequestStatus from ..response import SyftError from ..response import SyftException +from ..response import SyftInfo from ..response import SyftNotReady from ..response import SyftSuccess @@ -89,7 +90,7 @@ class ProjectEvent(SyftObject): creator_verify_key: Optional[SyftVerifyKey] signature: Optional[bytes] # dont use in signing - def __repr_syft_nested__(self): + def __repr_syft_nested__(self) -> tuple[str, str]: return ( short_qual_name(full_name_with_qualname(self)), f"{str(self.id)[:4]}...{str(self.id)[-3:]}", @@ -108,7 +109,7 @@ def rebase(self, project: Project) -> Self: prev_event = project.events[-1] if project.events else None self.project_id = project.id - if prev_event: + if prev_event and prev_event.seq_no is not None: self.prev_event_uid = prev_event.id self.prev_event_hash = prev_event.event_hash self.seq_no = prev_event.seq_no + 1 @@ -130,7 +131,8 @@ def valid(self) -> Union[SyftSuccess, SyftError]: raise Exception( f"Event hash {current_hash} does not match {self.event_hash}" ) - + if self.creator_verify_key is None: + return SyftError(message=f"{self}'s creator_verify_key is None") self.creator_verify_key.verify_key.verify(event_hash_bytes, self.signature) return SyftSuccess(message="Event signature is valid") except Exception as e: @@ -164,7 +166,11 @@ def valid_descendant( "does not match {prev_event_hash}" ) - if self.seq_no != prev_seq_no + 1: + if ( + (prev_seq_no is not None) + and (self.seq_no is not None) + and (self.seq_no != prev_seq_no + 1) + ): return SyftError( message=f"{self} seq_no: {self.seq_no} " "is not subsequent to {prev_seq_no}" @@ -178,7 +184,10 @@ def valid_descendant( if hasattr(self, "parent_event_id"): parent_event = project.event_id_hashmap[self.parent_event_id] - if type(self) not in parent_event.allowed_sub_types: + if ( + parent_event.allowed_sub_types is not None + and type(self) not in parent_event.allowed_sub_types + ): return SyftError( message=f"{self} is not a valid subevent" f"for {parent_event}" ) @@ -269,7 +278,7 @@ class ProjectRequest(ProjectEventAddObject): allowed_sub_types: List[Type] = [ProjectRequestResponse] @validator("linked_request", pre=True) - def _validate_linked_request(cls, v): + def _validate_linked_request(cls, v: Any) -> Union[Request, LinkedObject]: if isinstance(v, Request): linked_request = LinkedObject.from_obj(v, node_uid=v.node_uid) return linked_request @@ -281,7 +290,7 @@ def _validate_linked_request(cls, v): ) @property - def request(self): + def request(self) -> Request: return self.linked_request.resolve __repr_attrs__ = [ @@ -305,12 +314,14 @@ def approve(self) -> ProjectRequestResponse: return result return ProjectRequestResponse(response=True, parent_event_id=self.id) - def accept_by_depositing_result(self, result: Any, force: bool = False): + def accept_by_depositing_result( + self, result: Any, force: bool = False + ) -> Union[SyftError, SyftSuccess]: return self.request.accept_by_depositing_result(result=result, force=force) # TODO: To add deny requests, when deny functionality is added - def status(self, project: Project) -> Union[Dict, SyftError]: + def status(self, project: Project) -> Optional[Union[SyftInfo, SyftError]]: """Returns the status of the request. Args: @@ -323,7 +334,9 @@ def status(self, project: Project) -> Union[Dict, SyftError]: """ responses = project.get_children(self) if len(responses) == 0: - return "No one has responded to the request yet. Kindly recheck later 🙂" + return SyftInfo( + "No one has responded to the request yet. Kindly recheck later 🙂" + ) if len(responses) > 1: return SyftError( @@ -341,8 +354,10 @@ def status(self, project: Project) -> Union[Dict, SyftError]: print("Request Status : ", "Approved" if response.response else "Denied") + return None -def poll_creation_wizard() -> List[Any]: + +def poll_creation_wizard() -> tuple[str, list[str]]: w = textwrap.TextWrapper(initial_indent="\t", subsequent_indent="\t") welcome_msg = "Welcome to the Poll Creation Wizard 🧙‍♂️ 🪄!!!" @@ -490,8 +505,8 @@ def poll_answer_wizard(poll: ProjectMultipleChoicePoll) -> int: print(w.fill(f"Question : {poll.question}")) print() - for idx, choice in enumerate(poll.choices): - print(w.fill(f"{idx+1}. {choice}")) + for idx, choice_i in enumerate(poll.choices): + print(w.fill(f"{idx+1}. {choice_i}")) print() print("\t" + "-" * 69) @@ -501,7 +516,7 @@ def poll_answer_wizard(poll: ProjectMultipleChoicePoll) -> int: print() while True: try: - choice = int(input("\t")) + choice: int = int(input("\t")) if choice < 1 or choice > len(poll.choices): raise ValueError() except ValueError: @@ -543,7 +558,7 @@ class ProjectMultipleChoicePoll(ProjectEventAddObject): allowed_sub_types: List[Type] = [AnswerProjectPoll] @validator("choices") - def choices_min_length(cls, v): + def choices_min_length(cls, v: str) -> str: if len(v) < 1: raise ValueError("choices must have at least one item") return v @@ -553,7 +568,7 @@ def answer(self, answer: int) -> ProjectMessage: def status( self, project: Project, pretty_print: bool = True - ) -> Union[Dict, SyftError]: + ) -> Optional[Union[Dict, SyftError, SyftInfo]]: """Returns the status of the poll Args: @@ -567,7 +582,7 @@ def status( """ poll_answers = project.get_children(self) if len(poll_answers) == 0: - return "No one has answered this poll" + return SyftInfo(message="No one has answered this poll") respondents = {} for poll_answer in poll_answers[::-1]: @@ -587,6 +602,7 @@ def status( print("\nChoices:\n") for idx, choice in enumerate(self.choices): print(f"{idx+1}: {choice}") + return None else: return respondents @@ -614,9 +630,10 @@ def add_code_request_to_project( client: SyftClient, reason: Optional[str] = None, ) -> Union[SyftError, SyftSuccess]: + # TODO: fix the mypy issue if not isinstance(code, SubmitUserCode): - return SyftError( - message=f"Currently we are only support creating requests for SubmitUserCode: {type(code)}" + return SyftError( # type: ignore[unreachable] + message=f"Currently we are only support creating requests for SubmitUserCode: {type(code)}" ) if not isinstance(client, SyftClient): @@ -633,7 +650,7 @@ def add_code_request_to_project( request_event = ProjectRequest(linked_request=submitted_req) - if isinstance(project, ProjectSubmit): + if isinstance(project, ProjectSubmit) and project.bootstrap_events is not None: project.bootstrap_events.append(request_event) else: result = project.add_event(request_event) @@ -689,7 +706,7 @@ class Project(SyftObject): # store: Dict[UID, Dict[UID, SyftObject]] = {} # permissions: Dict[UID, Dict[UID, Set[str]]] = {} - def _coll_repr_(self): + def _coll_repr_(self) -> dict: return { "name": self.name, "description": self.description, @@ -873,7 +890,7 @@ def get_events( types: Optional[Union[Type, List[Type]]] = None, parent_event_ids: Optional[Union[UID, List[UID]]] = None, ids: Optional[Union[UID, List[UID]]] = None, - ): + ) -> list[ProjectEvent]: if types is None: types = [] if isinstance(types, type): @@ -917,7 +934,7 @@ def create_code_request( obj: SubmitUserCode, client: Optional[SyftClient] = None, reason: Optional[str] = None, - ): + ) -> Union[SyftSuccess, SyftError]: if client is None: leader_client = self.get_leader_client(self.user_signing_key) res = add_code_request_to_project( @@ -961,7 +978,7 @@ def messages(self) -> str: def get_last_seq_no(self) -> int: return len(self.events) - def send_message(self, message: str): + def send_message(self, message: str) -> Union[SyftSuccess, SyftError]: message_event = ProjectMessage(message=message) result = self.add_event(message_event) if isinstance(result, SyftSuccess): @@ -972,7 +989,7 @@ def reply_message( self, reply: str, message: Union[UID, ProjectMessage, ProjectThreadMessage], - ): + ) -> Union[SyftSuccess, SyftError]: if isinstance(message, UID): if message not in self.event_ids: return SyftError(message=f"Message id: {message} not found") @@ -1000,7 +1017,7 @@ def create_poll( self, question: Optional[str] = None, choices: Optional[List[str]] = None, - ): + ) -> Union[SyftSuccess, SyftError]: if ( question is None or choices is None @@ -1019,7 +1036,7 @@ def answer_poll( self, poll: Union[UID, ProjectMultipleChoicePoll], answer: Optional[int] = None, - ): + ) -> Union[SyftSuccess, SyftError]: if isinstance(poll, UID): if poll not in self.event_ids: return SyftError(message=f"Poll id: {poll} not found") @@ -1044,7 +1061,7 @@ def answer_poll( def add_request( self, request: Request, - ): + ) -> Union[SyftSuccess, SyftError]: linked_request = LinkedObject.from_obj(request, node_uid=request.node_uid) request_event = ProjectRequest(linked_request=linked_request) result = self.add_event(request_event) @@ -1058,7 +1075,7 @@ def add_request( def approve_request( self, request: Union[UID, ProjectRequest], - ): + ) -> Union[SyftError, SyftSuccess]: if isinstance(request, UID): if request not in self.event_ids: return SyftError(message=f"Request id: {request} not found") @@ -1174,7 +1191,7 @@ class ProjectSubmit(SyftObject): project_permissions: Set[str] = set() consensus_model: ConsensusModel = DemocraticConsensusModel() - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) # Preserve member SyftClients in a private variable clients @@ -1212,7 +1229,9 @@ def _repr_html_(self) -> Any: ) @validator("members", pre=True) - def verify_members(cls, val: Union[List[SyftClient], List[NodeIdentity]]): + def verify_members( + cls, val: Union[List[SyftClient], List[NodeIdentity]] + ) -> Union[List[SyftClient], List[NodeIdentity]]: # SyftClients must be logged in by the same emails clients = cls.get_syft_clients(val) if len(clients) > 0: @@ -1224,11 +1243,13 @@ def verify_members(cls, val: Union[List[SyftClient], List[NodeIdentity]]): return val @staticmethod - def get_syft_clients(vals: Union[List[SyftClient], List[NodeIdentity]]): + def get_syft_clients( + vals: Union[List[SyftClient], List[NodeIdentity]], + ) -> list[SyftClient]: return [client for client in vals if isinstance(client, SyftClient)] @staticmethod - def to_node_identity(val: Union[SyftClient, NodeIdentity]): + def to_node_identity(val: Union[SyftClient, NodeIdentity]) -> NodeIdentity: if isinstance(val, NodeIdentity): return val elif isinstance(val, SyftClient): @@ -1241,7 +1262,7 @@ def to_node_identity(val: Union[SyftClient, NodeIdentity]): def create_code_request( self, obj: SubmitUserCode, client: SyftClient, reason: Optional[str] = None - ): + ) -> Union[SyftError, SyftSuccess]: return add_code_request_to_project( project=self, code=obj, @@ -1249,7 +1270,7 @@ def create_code_request( reason=reason, ) - def start(self, return_all_projects=False) -> Project: + def start(self, return_all_projects: bool = False) -> Union[Project, list[Project]]: # Currently we are assuming that the first member is the leader # This would be changed in our future leaderless approach leader = self.clients[0] @@ -1275,7 +1296,7 @@ def start(self, return_all_projects=False) -> Project: except SyftException as exp: return SyftError(message=str(exp)) - def _pre_submit_checks(self, clients: List[SyftClient]): + def _pre_submit_checks(self, clients: List[SyftClient]) -> bool: try: # Check if the user can create projects for client in clients: @@ -1287,7 +1308,7 @@ def _pre_submit_checks(self, clients: List[SyftClient]): return True - def _exchange_routes(self, leader: SyftClient, followers: List[SyftClient]): + def _exchange_routes(self, leader: SyftClient, followers: List[SyftClient]) -> None: # Since we are implementing a leader based system # To be able to optimize exchanging routes. # We require only the leader to exchange routes with all the members @@ -1301,7 +1322,7 @@ def _exchange_routes(self, leader: SyftClient, followers: List[SyftClient]): self.leader_node_route = connection_to_route(leader.connection) - def _create_projects(self, clients: List[SyftClient]): + def _create_projects(self, clients: List[SyftClient]) -> Dict[SyftClient, Project]: projects: Dict[SyftClient, Project] = {} for client in clients: @@ -1312,7 +1333,7 @@ def _create_projects(self, clients: List[SyftClient]): return projects - def _bootstrap_events(self, leader_project: Project): + def _bootstrap_events(self, leader_project: Project) -> None: if not self.bootstrap_events: return diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py index 9fc393c1b11..215dcdfe5c2 100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -201,7 +201,10 @@ def close(self) -> None: @property def action_service(self) -> Callable: - return self.auth_context.node.get_service("ActionService") + if self.auth_context.node is not None: + return self.auth_context.node.get_service("ActionService") + else: + raise Exception(f"{self.auth_context} does not have a node.") def contains_unresolved_action_objects(self, arg: Any, recursion: int = 0) -> bool: """recursively check collections for unresolved action objects""" @@ -386,7 +389,8 @@ def update_consumer_state_for_worker( self, syft_worker_id: UID, consumer_state: ConsumerState ) -> None: if self.worker_stash is None: - logger.error( + # TODO: fix the mypy issue + logger.error( # type: ignore[unreachable] f"Worker stash is not defined for ZMQProducer : {self.queue_name} - {self.id}" ) return @@ -862,7 +866,7 @@ def add_producer( port: Optional[int] = None, queue_stash: Optional[QueueStash] = None, worker_stash: Optional[WorkerStash] = None, - context: AuthedServiceContext = None, + context: Optional[AuthedServiceContext] = None, ) -> ZMQProducer: """Add a producer of a queue. diff --git a/packages/syft/src/syft/service/worker/worker_pool_service.py b/packages/syft/src/syft/service/worker/worker_pool_service.py index 38f223332ea..85445c3727a 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_service.py +++ b/packages/syft/src/syft/service/worker/worker_pool_service.py @@ -6,6 +6,7 @@ # third party import pydantic +from result import OkErr # relative from ...custom_worker.config import CustomWorkerConfig @@ -112,7 +113,8 @@ def launch( ) worker_image: SyftWorkerImage = result.ok() - + if context.node is None: + return SyftError(message=f"context {context}'s node is None") worker_service: WorkerService = context.node.get_service("WorkerService") worker_stash = worker_service.stash @@ -218,6 +220,8 @@ def create_pool_request( # Create a the request object with the changes and submit it # for approval. request = SubmitRequest(changes=changes) + if context.node is None: + return SyftError(message=f"context {context}'s node is None") method = context.node.get_service_method(RequestService.submit) result = method(context=context, request=request, reason=reason) @@ -315,6 +319,8 @@ def create_image_and_pool_request( # Create a request object and submit a request for approval request = SubmitRequest(changes=changes) + if context.node is None: + return SyftError(message=f"context {context}'s node is None") method = context.node.get_service_method(RequestService.submit) result = method(context=context, request=request, reason=reason) @@ -399,6 +405,8 @@ def add_workers( worker_image: SyftWorkerImage = result.ok() + if context.node is None: + return SyftError(message=f"context {context}'s node is None") worker_service: WorkerService = context.node.get_service("WorkerService") worker_stash = worker_service.stash @@ -494,15 +502,16 @@ def scale( -(current_worker_count - number) : ] - worker_stash = context.node.get_service("WorkerService").stash - # delete linkedobj workers - for worker in workers_to_delete: - delete_result = worker_stash.delete_by_uid( - credentials=context.credentials, - uid=worker.object_uid, - ) - if delete_result.is_err(): - print(f"Failed to delete worker: {worker.object_uid}") + if context.node is not None: + worker_stash = context.node.get_service("WorkerService").stash + # delete linkedobj workers + for worker in workers_to_delete: + delete_result = worker_stash.delete_by_uid( + credentials=context.credentials, + uid=worker.object_uid, + ) + if delete_result.is_err(): + print(f"Failed to delete worker: {worker.object_uid}") # update worker_pool worker_pool.max_count = number @@ -644,6 +653,9 @@ def _create_workers_in_pool( reg_username: Optional[str] = None, reg_password: Optional[str] = None, ) -> Union[Tuple[List[LinkedObject], List[ContainerSpawnStatus]], SyftError]: + if context.node is None: + return SyftError(message=f"context {context}'s node is None") + queue_port = context.node.queue_config.client_config.queue_port # Check if workers needs to be run in memory or as containers @@ -690,15 +702,16 @@ def _create_workers_in_pool( obj=worker, ) - if result.is_ok(): - worker_obj = LinkedObject.from_obj( - obj=result.ok(), - service_type=WorkerService, - node_uid=context.node.id, - ) - linked_worker_list.append(worker_obj) - else: - container_status.error = result.err() + if isinstance(result, OkErr): + if result.is_ok() and context.node is not None: + worker_obj = LinkedObject.from_obj( + obj=result.ok(), + service_type=WorkerService, + node_uid=context.node.id, + ) + linked_worker_list.append(worker_obj) + elif isinstance(result, SyftError): + container_status.error = result.err() return linked_worker_list, container_statuses From 3c63ccfb4e778d66021e7661720df5a38752853e Mon Sep 17 00:00:00 2001 From: khoaguin Date: Tue, 20 Feb 2024 11:35:27 +0700 Subject: [PATCH 14/42] [refactor] done fixing mypy issues for `service/code` --- packages/syft/src/syft/node/node.py | 12 +- .../syft/src/syft/service/code/user_code.py | 258 ++++++++++-------- .../syft/service/code/user_code_service.py | 93 +++++-- .../src/syft/service/code/user_code_stash.py | 2 +- .../syft/service/enclave/enclave_service.py | 11 +- .../service/worker/worker_pool_service.py | 3 +- .../src/syft/service/worker/worker_service.py | 11 +- 7 files changed, 238 insertions(+), 152 deletions(-) diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index 45f5bbce5ad..000be6af7b7 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -1209,11 +1209,11 @@ def handle_api_call_with_unsigned_result( def add_action_to_queue( self, action, - credentials, + credentials: SyftVerifyKey, parent_job_id=None, has_execute_permissions: bool = False, worker_pool_name: Optional[str] = None, - ): + ) -> Union[Job, SyftError]: job_id = UID() task_uid = UID() worker_settings = WorkerSettings.from_node(node=self) @@ -1263,8 +1263,12 @@ def add_action_to_queue( ) def add_queueitem_to_queue( - self, queue_item, credentials, action=None, parent_job_id=None - ): + self, + queue_item: ActionQueueItem, + credentials: SyftVerifyKey, + action=None, + parent_job_id=None, + ) -> Union[Job, SyftError]: log_id = UID() role = self.get_role_for_credentials(credentials=credentials) context = AuthedServiceContext(node=self, credentials=credentials, role=role) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index a449845115d..c727431a6a2 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -137,10 +137,10 @@ class UserCodeStatusCollection(SyftHashableObject): def __init__(self, status_dict: Dict): self.status_dict = status_dict - def __repr__(self): + def __repr__(self) -> str: return str(self.status_dict) - def _repr_html_(self): + def _repr_html_(self) -> str: string = f"""