Skip to content

Commit

Permalink
[refactor] make node a non-opional attr for AuthedServiceContext
Browse files Browse the repository at this point in the history
reenable checking for mypy error `union-attr`
  • Loading branch information
khoaguin committed Feb 26, 2024
1 parent beab7ca commit c827525
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 5 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ repos:
"--install-types",
"--non-interactive",
"--config-file=tox.ini",
"--disable-error-code=union-attr", # todo: remove this line after fixing the issue context.node can be None
]

- repo: https://github.com/kynan/nbstripout
Expand Down
5 changes: 3 additions & 2 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,14 +944,15 @@ def syft_function(
output_policy_type = type(output_policy)

def decorator(f: Any) -> SubmitUserCode:
# TODO: fix the mypy issue below
res = SubmitUserCode(
code=inspect.getsource(f),
func_name=f.__name__,
signature=inspect.signature(f),
input_policy_type=input_policy_type,
input_policy_init_kwargs=input_policy.init_kwargs,
input_policy_init_kwargs=input_policy.init_kwargs, # type: ignore
output_policy_type=output_policy_type,
output_policy_init_kwargs=output_policy.init_kwargs,
output_policy_init_kwargs=output_policy.init_kwargs, # type: ignore
local_function=f,
input_kwargs=f.__code__.co_varnames[: f.__code__.co_argcount],
worker_pool_name=worker_pool_name,
Expand Down
1 change: 1 addition & 0 deletions packages/syft/src/syft/service/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class AuthedServiceContext(NodeServiceContext):
job_id: Optional[UID]
extra_kwargs: Dict = {}
has_execute_permissions: bool = False
node: AbstractNode

def capabilities(self) -> List[ServiceRoleCapability]:
return ROLE_TO_CAPABILITIES.get(self.role, [])
Expand Down
7 changes: 6 additions & 1 deletion packages/syft/src/syft/service/network/node_peer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

# third party
from typing_extensions import Self
Expand Down Expand Up @@ -115,7 +116,9 @@ def from_client(cls, client: SyftClient) -> Self:
peer.node_routes.append(route)
return peer

def client_with_context(self, context: NodeServiceContext) -> SyftClient:
def client_with_context(
self, context: NodeServiceContext
) -> Union[SyftClient, SyftError]:
if len(self.node_routes) < 1:
raise Exception(f"No routes to peer: {self}")
# select the latest added route
Expand All @@ -125,6 +128,8 @@ def client_with_context(self, context: NodeServiceContext) -> SyftClient:
client_type = connection.get_client_type()
if isinstance(client_type, SyftError):
return client_type
if context.node is None:
return SyftError(message=f"context {context}'s node is None")
return client_type(connection=connection, credentials=context.node.signing_key)

def client_with_key(self, credentials: SyftSigningKey) -> SyftClient:
Expand Down
6 changes: 5 additions & 1 deletion packages/syft/src/syft/service/network/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,15 @@


class NodeRoute:
def client_with_context(self, context: NodeServiceContext) -> SyftClient:
def client_with_context(
self, context: NodeServiceContext
) -> Union[SyftClient, SyftError]:
connection = route_to_connection(route=self, context=context)
client_type = connection.get_client_type()
if isinstance(client_type, SyftError):
return client_type
if context.node is None:
return SyftError(message=f"context {context}'s node is None")
return client_type(connection=connection, credentials=context.node.signing_key)

def validate_with_context(self, context: AuthedServiceContext) -> NodePeer:
Expand Down
2 changes: 2 additions & 0 deletions packages/syft/src/syft/service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def resolve_link(
if isinstance(obj, OkErr) and obj.is_ok():
obj = obj.ok()
if hasattr(obj, "node_uid"):
if context.node is None:
return SyftError(message=f"context {context}'s node is None")
obj.node_uid = context.node.id
if not isinstance(obj, OkErr):
obj = Ok(obj)
Expand Down

0 comments on commit c827525

Please sign in to comment.