Skip to content

Commit

Permalink
[refactor] fix mypy issues post merge with node-state-sync
Browse files Browse the repository at this point in the history
  • Loading branch information
khoaguin committed Mar 5, 2024
1 parent 8822a0a commit 77658a0
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 29 deletions.
11 changes: 7 additions & 4 deletions packages/syft/src/syft/client/syncing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ..service.job.job_stash import Job
from ..service.log.log import SyftLog
from ..service.sync.diff_state import NodeDiff
from ..service.sync.diff_state import ObjectDiffBatch
from ..service.sync.diff_state import ResolvedSyncState
from ..service.sync.sync_state import SyncState

Expand All @@ -20,7 +21,7 @@ def compare_states(low_state: SyncState, high_state: SyncState) -> NodeDiff:
return NodeDiff.from_sync_state(low_state=low_state, high_state=high_state)


def get_user_input_for_resolve():
def get_user_input_for_resolve() -> Optional[str]:
print(
"Do you want to keep the low state or the high state for these objects? choose 'low' or 'high'"
)
Expand All @@ -36,8 +37,8 @@ def get_user_input_for_resolve():


def resolve(
state: NodeDiff, decision: Optional[str] = None, share_private_objects=False
):
state: NodeDiff, decision: Optional[str] = None, share_private_objects: bool = False
) -> tuple[ResolvedSyncState, ResolvedSyncState]:
# TODO: only add permissions for objects where we manually give permission
# Maybe default read permission for some objects (high -> low)
resolved_state_low: ResolvedSyncState = ResolvedSyncState(alias="low")
Expand Down Expand Up @@ -81,7 +82,9 @@ def resolve(
return resolved_state_low, resolved_state_high


def get_user_input_for_batch_permissions(batch_diff, share_private_objects=False):
def get_user_input_for_batch_permissions(
batch_diff: ObjectDiffBatch, share_private_objects: bool = False
) -> None:
private_high_objects: List[Union[SyftLog, ActionObject]] = []

for diff in batch_diff.diffs:
Expand Down
15 changes: 12 additions & 3 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
from ..dataset.dataset import Asset
from ..job.job_stash import Job
from ..output.output_service import ExecutionOutput
from ..output.output_service import OutputService
from ..policy.policy import CustomInputPolicy
from ..policy.policy import CustomOutputPolicy
from ..policy.policy import EmpyInputPolicy
Expand Down Expand Up @@ -637,8 +638,8 @@ def get_output_history(
return SyftError(
message="Execution denied, Please wait for the code to be approved"
)

output_service = context.node.get_service("outputservice") # type: ignore
node = cast(AbstractNode, context.node)
output_service = cast(OutputService, node.get_service("outputservice"))
return output_service.get_by_user_code_id(context, self.id)

def apply_output(
Expand All @@ -654,7 +655,9 @@ def apply_output(
)

output_ids = filter_only_uids(outputs)
output_service = context.node.get_service("outputservice") # type: ignore
context.node = cast(AbstractNode, context.node)
output_service = context.node.get_service("outputservice")
output_service = cast(OutputService, output_service)
execution_result = output_service.create(
context,
user_code_id=self.id,
Expand Down Expand Up @@ -1330,6 +1333,12 @@ def create_code_status(context: TransformContext) -> TransformContext:
# relative
from .user_code_service import UserCodeService

if context.node is None:
raise ValueError(f"{context}'s node is None")

if context.output is None:
return context

input_keys = list(context.output["input_policy_init_kwargs"].keys())
code_link = LinkedObject.from_uid(
context.output["id"],
Expand Down
5 changes: 4 additions & 1 deletion packages/syft/src/syft/service/job/job_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ..action.action_permissions import ActionPermission
from ..code.user_code import UserCode
from ..context import AuthedServiceContext
from ..log.log_service import LogService
from ..queue.queue_stash import ActionQueueItem
from ..response import SyftError
from ..response import SyftSuccess
Expand Down Expand Up @@ -219,7 +220,9 @@ def add_read_permission_job_for_code_owner(
def add_read_permission_log_for_code_owner(
self, context: AuthedServiceContext, log_id: UID, user_code: UserCode
) -> Any:
log_service = context.node.get_service("logservice") # type: ignore
context.node = cast(AbstractNode, context.node)
log_service = context.node.get_service("logservice")
log_service = cast(LogService, log_service)
return log_service.stash.add_permission(
ActionObjectPermission(
log_id, ActionPermission.READ, user_code.user_verify_key
Expand Down
25 changes: 20 additions & 5 deletions packages/syft/src/syft/service/output/output_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# stdlib
from typing import Any
from typing import ClassVar
from typing import Dict
from typing import List
from typing import Optional
Expand Down Expand Up @@ -54,8 +55,17 @@ class ExecutionOutput(SyftObject):
# Output policy is not a linked object because its saved on the usercode
output_policy_id: Optional[UID] = None

__attr_searchable__: List[str] = ["user_code_id", "created_at", "output_policy_id"]
__repr_attrs__: List[str] = ["created_at", "user_code_id", "job_id", "output_ids"]
__attr_searchable__: ClassVar[List[str]] = [
"user_code_id",
"created_at",
"output_policy_id",
]
__repr_attrs__: ClassVar[List[str]] = [
"created_at",
"user_code_id",
"job_id",
"output_ids",
]

@pydantic.root_validator(pre=True)
def add_user_code_id(cls, values: dict) -> dict:
Expand Down Expand Up @@ -108,12 +118,17 @@ def from_ids(

@property
def outputs(self) -> Optional[Union[List[ActionObject], Dict[str, ActionObject]]]:
action_service = APIRegistry.api_for(
api = APIRegistry.api_for(
node_uid=self.syft_node_location,
user_verify_key=self.syft_client_verify_key,
).services.action
# TODO error handling for action_service.get
)
if api is None:
raise ValueError(
f"Can't access the api. Please log in to {self.syft_node_location}"
)
action_service = api.services.action

# TODO: error handling for action_service.get
if isinstance(self.output_ids, dict):
return {k: action_service.get(v) for k, v in self.output_ids.items()}
elif isinstance(self.output_ids, list):
Expand Down
7 changes: 4 additions & 3 deletions packages/syft/src/syft/service/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def extract_uid(v: Any) -> UID:
return value


def filter_only_uids(results: Any) -> Union[List[UID], Dict[str, UID]]:
def filter_only_uids(results: Any) -> Union[list[UID], dict[str, UID], UID]:
if not hasattr(results, "__len__"):
results = [results]

Expand Down Expand Up @@ -390,8 +390,9 @@ def is_valid(self) -> Union[SyftSuccess, SyftError]: # type: ignore
message=f"Policy is no longer valid. count: {execution_count} >= limit: {self.limit}"
)

def _is_valid(self, context: AuthedServiceContext) -> Union[SyftSuccess, SyftError]: # type: ignore
output_service = context.node.get_service("outputservice") # type: ignore
def _is_valid(self, context: AuthedServiceContext) -> Union[SyftSuccess, SyftError]:
context.node = cast(AbstractNode, context.node)
output_service = context.node.get_service("outputservice")
output_history = output_service.get_by_output_policy_id(context, self.id)
if isinstance(output_history, SyftError):
return output_history
Expand Down
3 changes: 3 additions & 0 deletions packages/syft/src/syft/service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def resolve_link(
obj = Ok(obj)
return obj

def get_all(*arg: Any, **kwargs: Any) -> Any:
pass


@serializable()
class BaseConfig(SyftBaseObject):
Expand Down
21 changes: 11 additions & 10 deletions packages/syft/src/syft/service/sync/diff_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,13 @@ class ObjectDiff(SyftObject): # StateTuple (compare 2 objects)
@classmethod
def from_objects(
cls,
low_obj: SyftObject,
high_obj: SyftObject,
low_obj: Optional[SyftObject],
high_obj: Optional[SyftObject],
low_permissions: List[ActionObjectPermission],
high_permissions: List[ActionObjectPermission],
) -> "ObjectDiff":
if low_obj is None and high_obj is None:
raise Exception("Both objects are None")
raise ValueError("Both low and high objects are None")
obj_type = type(low_obj if low_obj is not None else high_obj)

if low_obj is None or high_obj is None:
Expand Down Expand Up @@ -218,7 +218,7 @@ def object_id(self) -> UID:
return uid

@property
def non_empty_object(self) -> SyftObject:
def non_empty_object(self) -> Optional[SyftObject]:
return self.low_obj or self.high_obj

@property
Expand Down Expand Up @@ -271,6 +271,7 @@ def diff_side_str(self, side: str) -> str:
return res

def state_str(self, side: str) -> str:
other_obj: Optional[SyftObject] = None
if side == "high":
obj = self.high_obj
other_obj = self.low_obj
Expand All @@ -286,7 +287,7 @@ def state_str(self, side: str) -> str:
if isinstance(obj, ActionObject):
return obj.__repr__()

if other_obj is None:
if other_obj is None: # type: ignore[unreachable]
attrs_str = ""
attrs = getattr(obj, "__repr_attrs__", [])
for attr in attrs:
Expand All @@ -306,7 +307,7 @@ def state_str(self, side: str) -> str:

return attr_text

def get_obj(self) -> SyftObject:
def get_obj(self) -> Optional[SyftObject]:
if self.status == "NEW":
return self.low_obj if self.low_obj is not None else self.high_obj
else:
Expand Down Expand Up @@ -416,7 +417,7 @@ class ObjectDiffBatch(SyftObject):
@property
def visual_hierarchy(self) -> Tuple[Type, dict]:
# Returns
root_obj = (
root_obj: Union[Request, UserCodeStatusCollection, ExecutionOutput, Any] = (
self.root.low_obj if self.root.low_obj is not None else self.root.high_obj
)
if isinstance(root_obj, Request):
Expand Down Expand Up @@ -459,8 +460,8 @@ def __repr__(self) -> str:
{self.hierarchy_str('high')}
"""

def _repr_markdown_(self) -> None:
return None # Turns off the _repr_markdown_ of SyftObject
def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str:
return "" # Turns off the _repr_markdown_ of SyftObject

def _get_visual_hierarchy(self, node: ObjectDiff) -> dict[ObjectDiff, dict]:
_, child_types_map = self.visual_hierarchy
Expand Down Expand Up @@ -681,7 +682,7 @@ def _build_hierarchy_helper(
return hierarchies

def objs_to_sync(self) -> List[SyftObject]:
objs = []
objs: list[SyftObject] = []
for diff in self.diffs:
if diff.status == "NEW":
objs.append(diff.get_obj())
Expand Down
8 changes: 5 additions & 3 deletions packages/syft/src/syft/service/sync/sync_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from typing import List
from typing import Set
from typing import Union
from typing import cast

# third party
from result import Result

# relative
from ...abstract_node import AbstractNode
from ...client.api import NodeIdentity
from ...node.credentials import SyftVerifyKey
from ...serde.serializable import serializable
Expand Down Expand Up @@ -166,7 +168,7 @@ def sync_items(
context, item, other_node_permissions
)
else:
item = self.transform_item(context, item)
item = self.transform_item(context, item) # type: ignore[unreachable]
res = self.set_object(context, item)

if res.is_ok():
Expand Down Expand Up @@ -212,7 +214,7 @@ def get_state(
) -> Union[SyncState, SyftError]:
new_state = SyncState()

node = context.node
node = cast(AbstractNode, context.node)

services_to_sync = [
"projectservice",
Expand All @@ -225,7 +227,7 @@ def get_state(
]

for service_name in services_to_sync:
service = node.get_service(service_name) # type: ignore
service = node.get_service(service_name)
items = service.get_all(context)
new_state.add_objects(items, api=node.root_client.api) # type: ignore

Expand Down

0 comments on commit 77658a0

Please sign in to comment.