Skip to content

Commit

Permalink
Merge pull request #8523 from khoaguin/fix-mypy-issues-syft
Browse files Browse the repository at this point in the history
[Refactor] Fixing mypy issues of `syft/`
  • Loading branch information
shubham3121 authored Mar 5, 2024
2 parents 415fb4f + 79ffa8c commit 07ce75e
Show file tree
Hide file tree
Showing 97 changed files with 1,393 additions and 832 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ repos:
- id: mypy
name: "mypy: syft"
always_run: true
files: "^packages/syft/src/syft/serde|^packages/syft/src/syft/util|^packages/syft/src/syft/service"
# files: "^packages/syft/src/syft/"
files: "^packages/syft/src/syft/"
exclude: "packages/syft/src/syft/types/dicttuple.py|^packages/syft/src/syft/service/action/action_graph.py|^packages/syft/src/syft/external/oblv/"
args: [
"--follow-imports=skip",
"--ignore-missing-imports",
Expand Down
3 changes: 3 additions & 0 deletions packages/syft/src/syft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@
logger.start()

try:
# third party
from IPython import get_ipython

get_ipython() # noqa: F821
# TODO: add back later or auto detect
# display(
Expand Down
7 changes: 6 additions & 1 deletion packages/syft/src/syft/abstract_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@
from enum import Enum
from typing import Callable
from typing import Optional
from typing import TYPE_CHECKING
from typing import Union

# relative
from .serde.serializable import serializable
from .types.uid import UID

if TYPE_CHECKING:
# relative
from .service.service import AbstractService


@serializable()
class NodeType(str, Enum):
Expand Down Expand Up @@ -37,5 +42,5 @@ class AbstractNode:
node_side_type: Optional[NodeSideType]
in_memory_workers: bool

def get_service(self, path_or_func: Union[str, Callable]) -> Callable:
def get_service(self, path_or_func: Union[str, Callable]) -> "AbstractService":
raise NotImplementedError
44 changes: 25 additions & 19 deletions packages/syft/src/syft/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Tuple
from typing import Union
from typing import _GenericAlias
from typing import cast
from typing import get_args
from typing import get_origin

Expand Down Expand Up @@ -241,8 +242,8 @@ def __ipython_inspector_signature_override__(self) -> Optional[Signature]:
return self.signature

def prepare_args_and_kwargs(
self, args: List[Any], kwargs: Dict[str, Any]
) -> Union[SyftError, Tuple[List[Any], Dict[str, Any]]]:
self, args: Union[list, tuple], kwargs: dict[str, Any]
) -> Union[SyftError, tuple[tuple, dict[str, Any]]]:
# Validate and migrate args and kwargs
res = validate_callable_args_and_kwargs(args, kwargs, self.signature)
if isinstance(res, SyftError):
Expand Down Expand Up @@ -279,7 +280,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
api_call = SyftAPICall(
node_uid=self.node_uid,
path=self.path,
args=_valid_args,
args=list(_valid_args),
kwargs=_valid_kwargs,
blocking=blocking,
)
Expand All @@ -304,8 +305,8 @@ class RemoteUserCodeFunction(RemoteFunction):
api: SyftAPI

def prepare_args_and_kwargs(
self, args: List[Any], kwargs: Dict[str, Any]
) -> Union[SyftError, Tuple[List[Any], Dict[str, Any]]]:
self, args: Union[list, tuple], kwargs: Dict[str, Any]
) -> Union[SyftError, tuple[tuple, dict[str, Any]]]:
# relative
from ..service.action.action_object import convert_to_pointers

Expand Down Expand Up @@ -506,9 +507,12 @@ def _repr_html_(self) -> Any:
results = self.get_all()
return results._repr_html_()

def __call__(self, *args: Any, **kwargs: Any) -> Any:
return NotImplementedError


def debox_signed_syftapicall_response(
signed_result: SignedSyftAPICall,
signed_result: Union[SignedSyftAPICall, Any],
) -> Union[Any, SyftError]:
if not isinstance(signed_result, SignedSyftAPICall):
return SyftError(message="The result is not signed")
Expand Down Expand Up @@ -825,16 +829,16 @@ def build_endpoint_tree(
)

@property
def services(self) -> Optional[APIModule]:
def services(self) -> APIModule:
if self.api_module is None:
self.generate_endpoints()
return self.api_module
return cast(APIModule, self.api_module)

@property
def lib(self) -> Optional[APIModule]:
def lib(self) -> APIModule:
if self.libs is None:
self.generate_endpoints()
return self.libs
return cast(APIModule, self.libs)

def has_service(self, service_name: str) -> bool:
return hasattr(self.services, service_name)
Expand Down Expand Up @@ -940,19 +944,21 @@ class NodeIdentity(Identity):
node_name: str

@staticmethod
def from_api(api: SyftAPI) -> Optional[NodeIdentity]:
def from_api(api: SyftAPI) -> NodeIdentity:
# stores the name root verify key of the domain node
if api.connection is not None:
node_metadata = api.connection.get_node_metadata(api.signing_key)
return NodeIdentity(
node_name=node_metadata.name,
node_id=api.node_uid,
verify_key=SyftVerifyKey.from_string(node_metadata.verify_key),
)
return None
if api.connection is None:
raise ValueError("{api}'s connection is None. Can't get the node identity")
node_metadata = api.connection.get_node_metadata(api.signing_key)
return NodeIdentity(
node_name=node_metadata.name,
node_id=api.node_uid,
verify_key=SyftVerifyKey.from_string(node_metadata.verify_key),
)

@classmethod
def from_change_context(cls, context: ChangeContext) -> NodeIdentity:
if context.node is None:
raise ValueError(f"{context}'s node is None")
return cls(
node_name=context.node.name,
node_id=context.node.id,
Expand Down
Loading

0 comments on commit 07ce75e

Please sign in to comment.