From 2c05d6b449a91b1db385be285abfa64c55f6b286 Mon Sep 17 00:00:00 2001
From: Ionesio Junior
Date: Fri, 10 May 2024 14:33:32 -0300
Subject: [PATCH 1/5] ADD small changes to make CustomInputPolicies possible

---
 .../syft/src/syft/service/code/user_code.py |  6 +++++-
 .../syft/src/syft/service/policy/policy.py  | 20 +++++++++++++++++++
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py
index cf99a8cc589..415f8006e09 100644
--- a/packages/syft/src/syft/service/code/user_code.py
+++ b/packages/syft/src/syft/service/code/user_code.py
@@ -76,6 +76,7 @@
 from ..policy.policy import filter_only_uids
 from ..policy.policy import init_policy
 from ..policy.policy import load_policy_code
+from ..policy.policy import partition_by_node
 from ..policy.policy_service import PolicyService
 from ..response import SyftError
 from ..response import SyftInfo
@@ -952,10 +953,13 @@ def syft_function(
     if input_policy is None:
         input_policy = EmpyInputPolicy()
 
+    init_input_kwargs = None
     if isinstance(input_policy, CustomInputPolicy):
         input_policy_type = SubmitUserPolicy.from_obj(input_policy)
+        init_input_kwargs = partition_by_node(input_policy.init_kwargs)
     else:
         input_policy_type = type(input_policy)
+        init_input_kwargs = getattr(input_policy, "init_kwargs", {})
 
     if output_policy is None:
         output_policy = SingleExecutionExactOutput()
@@ -971,7 +975,7 @@ def decorator(f: Any) -> SubmitUserCode:
             func_name=f.__name__,
             signature=inspect.signature(f),
             input_policy_type=input_policy_type,
-            input_policy_init_kwargs=getattr(input_policy, "init_kwargs", {}),
+            input_policy_init_kwargs=init_input_kwargs,
             output_policy_type=output_policy_type,
             output_policy_init_kwargs=getattr(output_policy, "init_kwargs", {}),
             local_function=f,
diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py
index 04f49bac453..c5fda8e0f3e 100644
--- a/packages/syft/src/syft/service/policy/policy.py
+++ b/packages/syft/src/syft/service/policy/policy.py
@@ -854,6 +854,26 @@ def load_policy_code(user_policy: UserPolicy) -> Any:
 def init_policy(user_policy: UserPolicy, init_args: dict[str, Any]) -> Any:
     policy_class = load_policy_code(user_policy)
     policy_object = policy_class()
+
+    # Unwrap {NodeIdentity: {x: y}} -> {x: y}
+    # Tech debt: input policies require NodeIdentity-keyed args up front,
+    # so at this stage we have to convert them back to plain args.
+    # Maybe there's a better way to do it.
+    if len(init_args) and isinstance(list(init_args.keys())[0], NodeIdentity):
+        unwrapped_init_kwargs = init_args
+        if len(init_args) > 1:
+            raise Exception("You shouldn't have more than one NodeIdentity.")
+        # Otherwise, unwrap it
+        init_args = init_args[list(init_args.keys())[0]]
+        init_args = {k: v for k, v in init_args.items() if k != "id"}
+
+    # For input policies, this initializer wouldn't work properly:
+    # 1 - Passing {NodeIdentity: {kwargs: UIDs}} as keyword args doesn't work, since dict keys must be strings.
+    # 2 - Passing {kwargs: UIDs} to this initializer would not trigger the node partitioning done by the
+    #     InputPolicy initializer.
+    # The cleanest way to solve it is to check whether it's an InputPolicy and then set init_kwargs manually.
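+    # Editor's note (illustrative sketch, not part of the original change): with a
+    # hypothetical NodeIdentity key, init_args may arrive shaped like
+    #   {NodeIdentity(node_name="node-a", ...): {"x": <UID>, "id": <UID>}}
+    # and the block above unwraps it to {"x": <UID>} (dropping "id") before
+    # __user_init__ is called below.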
policy_object.__user_init__(**init_args) + if isinstance(policy_object, InputPolicy): + policy_object.init_kwargs = unwrapped_init_kwargs return policy_object From 97e4c48121a12621036b73236b979a5efba8158f Mon Sep 17 00:00:00 2001 From: Ionesio Junior Date: Fri, 10 May 2024 14:38:42 -0300 Subject: [PATCH 2/5] ADD custom input policy test at custom-policy notebook --- notebooks/api/0.8/05-custom-policy.ipynb | 218 ++++++++++++++++++++--- 1 file changed, 198 insertions(+), 20 deletions(-) diff --git a/notebooks/api/0.8/05-custom-policy.ipynb b/notebooks/api/0.8/05-custom-policy.ipynb index 75b8f43e5e7..fac27bfcbe8 100644 --- a/notebooks/api/0.8/05-custom-policy.ipynb +++ b/notebooks/api/0.8/05-custom-policy.ipynb @@ -238,13 +238,207 @@ "cell_type": "code", "execution_count": null, "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "# third party\n", + "from result import Err\n", + "from result import Ok\n", + "\n", + "# syft absolute\n", + "from syft.client.api import AuthedServiceContext\n", + "from syft.client.api import NodeIdentity\n", + "\n", + "\n", + "class CustomExactMatch(sy.CustomInputPolicy):\n", + " def __init__(self, *args: Any, **kwargs: Any) -> None:\n", + " pass\n", + "\n", + " def filter_kwargs(self, kwargs, context, code_item_id):\n", + " # stdlib\n", + "\n", + " try:\n", + " allowed_inputs = self.allowed_ids_only(\n", + " allowed_inputs=self.inputs, kwargs=kwargs, context=context\n", + " )\n", + " results = self.retrieve_from_db(\n", + " code_item_id=code_item_id,\n", + " allowed_inputs=allowed_inputs,\n", + " context=context,\n", + " )\n", + " except Exception as e:\n", + " return Err(str(e))\n", + " return results\n", + "\n", + " def retrieve_from_db(self, code_item_id, allowed_inputs, context):\n", + " # syft absolute\n", + " from syft import NodeType\n", + " from syft.service.action.action_object import TwinMode\n", + "\n", + " action_service = context.node.get_service(\"actionservice\")\n", + " code_inputs = {}\n", + "\n", + " # When we are retrieving the code from the database, we need to use the node's\n", + " # verify key as the credentials. 
This is because when we approve the code, we\n", + " # we allow the private data to be used only for this specific code.\n", + " # but we are not modifying the permissions of the private data\n", + "\n", + " root_context = AuthedServiceContext(\n", + " node=context.node, credentials=context.node.verify_key\n", + " )\n", + " if context.node.node_type == NodeType.DOMAIN:\n", + " for var_name, arg_id in allowed_inputs.items():\n", + " kwarg_value = action_service._get(\n", + " context=root_context,\n", + " uid=arg_id,\n", + " twin_mode=TwinMode.NONE,\n", + " has_permission=True,\n", + " )\n", + " if kwarg_value.is_err():\n", + " return Err(kwarg_value.err())\n", + " code_inputs[var_name] = kwarg_value.ok()\n", + "\n", + " elif context.node.node_type == NodeType.ENCLAVE:\n", + " dict_object = action_service.get(context=root_context, uid=code_item_id)\n", + " if dict_object.is_err():\n", + " return Err(dict_object.err())\n", + " for value in dict_object.ok().syft_action_data.values():\n", + " code_inputs.update(value)\n", + "\n", + " else:\n", + " raise Exception(\n", + " f\"Invalid Node Type for Code Submission:{context.node.node_type}\"\n", + " )\n", + " return Ok(code_inputs)\n", + "\n", + " def allowed_ids_only(\n", + " self,\n", + " allowed_inputs,\n", + " kwargs,\n", + " context,\n", + " ):\n", + " # syft absolute\n", + " from syft import NodeType\n", + " from syft import UID\n", + "\n", + " if context.node.node_type == NodeType.DOMAIN:\n", + " node_identity = NodeIdentity(\n", + " node_name=context.node.name,\n", + " node_id=context.node.id,\n", + " verify_key=context.node.signing_key.verify_key,\n", + " )\n", + " allowed_inputs = allowed_inputs.get(node_identity, {})\n", + " elif context.node.node_type == NodeType.ENCLAVE:\n", + " base_dict = {}\n", + " for key in allowed_inputs.values():\n", + " base_dict.update(key)\n", + " allowed_inputs = base_dict\n", + " else:\n", + " raise Exception(\n", + " f\"Invalid Node Type for Code Submission:{context.node.node_type}\"\n", + " )\n", + " filtered_kwargs = {}\n", + " for key in allowed_inputs.keys():\n", + " if key in kwargs:\n", + " value = kwargs[key]\n", + " uid = value\n", + " if not isinstance(uid, UID):\n", + " uid = getattr(value, \"id\", None)\n", + "\n", + " if uid != allowed_inputs[key]:\n", + " raise Exception(\n", + " f\"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}\"\n", + " )\n", + " filtered_kwargs[key] = value\n", + " return filtered_kwargs\n", + "\n", + " def _is_valid(\n", + " self,\n", + " context,\n", + " usr_input_kwargs,\n", + " code_item_id,\n", + " ):\n", + " filtered_input_kwargs = self.filter_kwargs(\n", + " kwargs=usr_input_kwargs,\n", + " context=context,\n", + " code_item_id=code_item_id,\n", + " )\n", + "\n", + " if filtered_input_kwargs.is_err():\n", + " return filtered_input_kwargs\n", + "\n", + " filtered_input_kwargs = filtered_input_kwargs.ok()\n", + "\n", + " expected_input_kwargs = set()\n", + " for _inp_kwargs in self.inputs.values():\n", + " for k in _inp_kwargs.keys():\n", + " if k not in usr_input_kwargs:\n", + " return Err(f\"Function missing required keyword argument: '{k}'\")\n", + " expected_input_kwargs.update(_inp_kwargs.keys())\n", + "\n", + " permitted_input_kwargs = list(filtered_input_kwargs.keys())\n", + " not_approved_kwargs = set(expected_input_kwargs) - set(permitted_input_kwargs)\n", + " if len(not_approved_kwargs) > 0:\n", + " return Err(\n", + " f\"Input arguments: {not_approved_kwargs} to the function are not approved yet.\"\n", + " )\n", + " return 
Ok(True)\n", + "\n", + "\n", + "def allowed_ids_only(\n", + " self,\n", + " allowed_inputs,\n", + " kwargs,\n", + " context,\n", + "):\n", + " # syft absolute\n", + " from syft import NodeType\n", + " from syft import UID\n", + " from syft.client.api import NodeIdentity\n", + "\n", + " if context.node.node_type == NodeType.DOMAIN:\n", + " node_identity = NodeIdentity(\n", + " node_name=context.node.name,\n", + " node_id=context.node.id,\n", + " verify_key=context.node.signing_key.verify_key,\n", + " )\n", + " allowed_inputs = allowed_inputs.get(node_identity, {})\n", + " elif context.node.node_type == NodeType.ENCLAVE:\n", + " base_dict = {}\n", + " for key in allowed_inputs.values():\n", + " base_dict.update(key)\n", + " allowed_inputs = base_dict\n", + " else:\n", + " raise Exception(\n", + " f\"Invalid Node Type for Code Submission:{context.node.node_type}\"\n", + " )\n", + " filtered_kwargs = {}\n", + " for key in allowed_inputs.keys():\n", + " if key in kwargs:\n", + " value = kwargs[key]\n", + " uid = value\n", + " if not isinstance(uid, UID):\n", + " uid = getattr(value, \"id\", None)\n", + "\n", + " if uid != allowed_inputs[key]:\n", + " raise Exception(\n", + " f\"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}\"\n", + " )\n", + " filtered_kwargs[key] = value\n", + " return filtered_kwargs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", "metadata": { "tags": [] }, "outputs": [], "source": [ "@sy.syft_function(\n", - " input_policy=sy.ExactMatch(x=x_pointer),\n", + " input_policy=CustomExactMatch(x=x_pointer),\n", " output_policy=RepeatedCallPolicy(n_calls=10, downloadable_output_args=[\"y\"]),\n", ")\n", "def func(x):\n", @@ -254,7 +448,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "18", "metadata": { "tags": [] }, @@ -267,21 +461,13 @@ { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "19", "metadata": {}, "outputs": [], "source": [ "request_id = request.id" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "19", - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": null, @@ -433,14 +619,6 @@ "if node.node_type.value == \"python\":\n", " node.land()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "32", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -459,7 +637,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": "3.11.0rc1" }, "toc": { "base_numbering": 1, From e1fc6e34fcea9d67037de00992626fbf1be7ca21 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 22 May 2024 11:12:37 +0200 Subject: [PATCH 3/5] fix filter --- .../syft/src/syft/service/sync/diff_state.py | 175 ++++++++++++++++-- 1 file changed, 163 insertions(+), 12 deletions(-) diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py index 014e33f5bc8..82b1d6021a8 100644 --- a/packages/syft/src/syft/service/sync/diff_state.py +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -1,12 +1,17 @@ # stdlib +from collections.abc import Callable from collections.abc import Iterable +from dataclasses import dataclass +import enum import html +import operator import textwrap from typing import Any from typing import ClassVar from typing import Literal # third party +from loguru import logger import pandas as pd from pydantic import model_validator from rich import box @@ -50,6 +55,7 @@ from 
..request.request import Request from ..response import SyftError from ..response import SyftSuccess +from ..user.user import UserView from .sync_state import SyncState sketchy_tab = "‎ " * 4 @@ -944,19 +950,32 @@ def visual_root(self) -> ObjectDiff: @property def user_code_high(self) -> UserCode | None: """return the user code of the high side of this batch, if it exists""" - user_codes_high: list[UserCode] = [ - diff.high_obj + user_code_diff = self.user_code_diff + if user_code_diff is not None and isinstance(user_code_diff.high_obj, UserCode): + return user_code_diff.high_obj + return None + + @property + def user_code_diff(self) -> ObjectDiff | None: + """return the main user code diff of the high side of this batch, if it exists""" + user_code_diffs: list[ObjectDiff] = [ + diff for diff in self.get_dependencies(include_roots=True) - if isinstance(diff.high_obj, UserCode) + if issubclass(diff.obj_type, UserCode) ] - if len(user_codes_high) == 0: - user_code_high = None + if len(user_code_diffs) == 0: + return None else: - # NOTE we can always assume the first usercode is - # not a nested code, because diffs are sorted in depth-first order - user_code_high = user_codes_high[0] - return user_code_high + # main usercode is always the first, batches are sorted in depth-first order + return user_code_diffs[0] + + @property + def user(self) -> UserView | SyftError: + user_code_diff = self.user_code_diff + if user_code_diff is not None and isinstance(user_code_diff.low_obj, UserCode): + return user_code_diff.low_obj.user + return SyftError(message="No user found") def get_visual_hierarchy(self) -> dict[ObjectDiff, dict]: visual_root = self.visual_root @@ -1022,6 +1041,70 @@ def stage_change(self) -> None: other_batch.decision = None +class FilterProperty(enum.Enum): + USER = enum.auto() + TYPE = enum.auto() + STATUS = enum.auto() + IGNORED = enum.auto() + + def from_batch(self, batch: ObjectDiffBatch) -> Any: + if self == FilterProperty.USER: + user = batch.user + if isinstance(user, UserView): + return user.email + return None + elif self == FilterProperty.BATCH_TYPE: + return batch.root_diff.obj_type + elif self == FilterProperty.STATUS: + return batch.status + elif self == FilterProperty.IGNORED: + return batch.is_ignored + else: + raise ValueError(f"Invalid property: {property}") + + +@dataclass +class NodeDiffFilter: + """ + Filter to apply to a NodeDiff object to determine if it should be included in a batch. + + Tests for `property op value` , where + property: FilterProperty - property to filter on + value: Any - value to compare against + op: callable[[Any, Any], bool] - comparison operator. Default is `operator.eq` + + If the comparison fails, the batch is included in the result. 
+ """ + + filter_property: FilterProperty + filter_value: Any + op: Callable[[Any, Any], bool] = operator.eq + + def __call__(self, batch: ObjectDiffBatch) -> bool: + try: + p = self.filter_property.from_batch(batch) + if self.op == operator.contains: + # Contains check has reversed arg order: check if p in self.filter_value + return p in self.filter_value + else: + return self.op(p, self.filter_value) + except Exception as e: + logger.debug(f"Error filtering batch {batch} with {self}: {e}") + return True + + def __hash__(self) -> int: + return hash(self.filter_property) + hash(self.filter_value) + hash(self.op) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, NodeDiffFilter): + return False + return ( + self.filter_property == other.filter_property + and self.filter_value == other.filter_value + and self.op == other.op + ) + + class NodeDiff(SyftObject): __canonical_name__ = "NodeDiff" __version__ = SYFT_OBJECT_VERSION_2 @@ -1037,6 +1120,7 @@ class NodeDiff(SyftObject): low_state: SyncState high_state: SyncState direction: SyncDirection | None + filters: list[NodeDiffFilter] = [] include_ignored: bool = False @@ -1074,6 +1158,7 @@ def from_sync_state( high_state: SyncState, direction: SyncDirection, include_ignored: bool = False, + include_same: bool = False, _include_node_status: bool = False, ) -> "NodeDiff": obj_uid_to_diff = {} @@ -1126,10 +1211,15 @@ def from_sync_state( previously_ignored_batches = low_state.ignored_batches NodeDiff.apply_previous_ignore_state(all_batches, previously_ignored_batches) + filters = [] if not include_ignored: - batches = [b for b in all_batches if not b.is_ignored] - else: - batches = all_batches + filters.append(NodeDiffFilter(FilterProperty.IGNORED, True, operator.ne)) + if not include_same: + filters.append(NodeDiffFilter(FilterProperty.STATUS, "SAME", operator.ne)) + + batches = all_batches + for f in filters: + batches = [b for b in batches if f(b)] return cls( low_node_uid=low_state.node_uid, @@ -1143,6 +1233,7 @@ def from_sync_state( low_state=low_state, high_state=high_state, direction=direction, + filters=filters, ) @staticmethod @@ -1317,6 +1408,66 @@ def hierarchies( def is_same(self) -> bool: return all(object_diff.status == "SAME" for object_diff in self.diffs) + def _apply_filters(self, filters: list[NodeDiffFilter]) -> Self: + """ + Apply filters to the NodeDiff object and return a new NodeDiff object + """ + batches = self.all_batches + for filter in filters: + batches = [b for b in batches if filter(b)] + return NodeDiff( + low_node_uid=self.low_node_uid, + high_node_uid=self.high_node_uid, + user_verify_key_low=self.user_verify_key_low, + user_verify_key_high=self.user_verify_key_high, + obj_uid_to_diff=self.obj_uid_to_diff, + obj_dependencies=self.obj_dependencies, + batches=batches, + all_batches=self.all_batches, + low_state=self.low_state, + high_state=self.high_state, + direction=self.direction, + filters=filters, + ) + + def reset_filters( + self, + include_ignored: bool = False, + include_same: bool = False, + ) -> Self: + filters = [] + if not include_ignored: + filters.append(NodeDiffFilter(FilterProperty.IGNORED, True, operator.ne)) + if not include_same: + filters.append(NodeDiffFilter(FilterProperty.STATUS, "SAME", operator.ne)) + return self._apply_filters(filters) + + def filter( + self, + user: str | None = None, + obj_type: type | None = None, + ) -> Self: + current_filters = self.filters + new_filters = [] + if user is not None: + new_filters.append(NodeDiffFilter(FilterProperty.USER, user)) + if 
obj_type is not None: + new_filters.append(NodeDiffFilter(FilterProperty.TYPE, obj_type)) + + if len(new_filters) == 0: + return self + + new_filter_properties = {f.filter_property for f in new_filters} + # Only add filters that are not in the new filters + # - remove duplicate filters + # - overwrite filters with the same property but different value + # (example: cannot filter on 2 different users) + for current_filter in current_filters: + if current_filter.filter_property not in new_filter_properties: + new_filters.append(current_filter) + + return self._apply_filters(new_filters) + class SyncInstruction(SyftObject): __canonical_name__ = "SyncDecision" From fafa4b9b9a26243865c5800f0ab8b3b03b2e3de0 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 22 May 2024 13:19:28 +0200 Subject: [PATCH 4/5] move filters to compare_client method --- packages/syft/src/syft/client/syncing.py | 28 +++- .../syft/src/syft/service/sync/diff_state.py | 147 +++++++++--------- 2 files changed, 96 insertions(+), 79 deletions(-) diff --git a/packages/syft/src/syft/client/syncing.py b/packages/syft/src/syft/client/syncing.py index 371b77df22a..7185cb5316e 100644 --- a/packages/syft/src/syft/client/syncing.py +++ b/packages/syft/src/syft/client/syncing.py @@ -18,7 +18,12 @@ def compare_states( - from_state: SyncState, to_state: SyncState, include_ignored: bool = False + from_state: SyncState, + to_state: SyncState, + include_ignored: bool = False, + include_same: bool = False, + filter_by_email: str | None = None, + filter_by_type: str | type | None = None, ) -> NodeDiff: # NodeDiff if ( @@ -42,11 +47,28 @@ def compare_states( high_state=high_state, direction=direction, include_ignored=include_ignored, + include_same=include_same, + filter_by_email=filter_by_email, + filter_by_type=filter_by_type, ) -def compare_clients(low_client: SyftClient, high_client: SyftClient) -> NodeDiff: - return compare_states(low_client.get_sync_state(), high_client.get_sync_state()) +def compare_clients( + from_client: SyftClient, + to_client: SyftClient, + include_ignored: bool = False, + include_same: bool = False, + filter_by_email: str | None = None, + filter_by_type: type | None = None, +) -> NodeDiff: + return compare_states( + from_client.get_sync_state(), + to_client.get_sync_state(), + include_ignored=include_ignored, + include_same=include_same, + filter_by_email=filter_by_email, + filter_by_type=filter_by_type, + ) def get_user_input_for_resolve() -> SyncDecision: diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py index 4c24487edb4..f40e86306c3 100644 --- a/packages/syft/src/syft/service/sync/diff_state.py +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -1054,10 +1054,10 @@ def from_batch(self, batch: ObjectDiffBatch) -> Any: if isinstance(user, UserView): return user.email return None - elif self == FilterProperty.BATCH_TYPE: - return batch.root_diff.obj_type + elif self == FilterProperty.TYPE: + return batch.root_diff.obj_type.__name__.lower() elif self == FilterProperty.STATUS: - return batch.status + return batch.status.lower() elif self == FilterProperty.IGNORED: return batch.is_ignored else: @@ -1069,7 +1069,7 @@ class NodeDiffFilter: """ Filter to apply to a NodeDiff object to determine if it should be included in a batch. - Tests for `property op value` , where + Checks for `property op value` , where property: FilterProperty - property to filter on value: Any - value to compare against op: callable[[Any, Any], bool] - comparison operator. 
Default is `operator.eq` @@ -1082,28 +1082,21 @@ class NodeDiffFilter: op: Callable[[Any, Any], bool] = operator.eq def __call__(self, batch: ObjectDiffBatch) -> bool: + filter_value = self.filter_value + if isinstance(filter_value, str): + filter_value = filter_value.lower() + try: p = self.filter_property.from_batch(batch) if self.op == operator.contains: # Contains check has reversed arg order: check if p in self.filter_value - return p in self.filter_value + return p in filter_value else: - return self.op(p, self.filter_value) + return self.op(p, filter_value) except Exception as e: + # By default, exclude the batch if there is an error logger.debug(f"Error filtering batch {batch} with {self}: {e}") - return True - - def __hash__(self) -> int: - return hash(self.filter_property) + hash(self.filter_value) + hash(self.op) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, NodeDiffFilter): return False - return ( - self.filter_property == other.filter_property - and self.filter_value == other.filter_value - and self.op == other.op - ) class NodeDiff(SyftObject): @@ -1160,6 +1153,8 @@ def from_sync_state( direction: SyncDirection, include_ignored: bool = False, include_same: bool = False, + filter_by_email: str | None = None, + filter_by_type: type | None = None, _include_node_status: bool = False, ) -> "NodeDiff": obj_uid_to_diff = {} @@ -1212,31 +1207,31 @@ def from_sync_state( previously_ignored_batches = low_state.ignored_batches NodeDiff.apply_previous_ignore_state(all_batches, previously_ignored_batches) - filters = [] - if not include_ignored: - filters.append(NodeDiffFilter(FilterProperty.IGNORED, True, operator.ne)) - if not include_same: - filters.append(NodeDiffFilter(FilterProperty.STATUS, "SAME", operator.ne)) - - batches = all_batches - for f in filters: - batches = [b for b in batches if f(b)] - - return cls( + res = cls( low_node_uid=low_state.node_uid, high_node_uid=high_state.node_uid, user_verify_key_low=low_state.syft_client_verify_key, user_verify_key_high=high_state.syft_client_verify_key, obj_uid_to_diff=obj_uid_to_diff, obj_dependencies=obj_dependencies, - batches=batches, + batches=all_batches, all_batches=all_batches, low_state=low_state, high_state=high_state, direction=direction, - filters=filters, + filters=[], + ) + + res._filter( + user_email=filter_by_email, + obj_type=filter_by_type, + include_ignored=include_ignored, + include_same=include_same, + inplace=True, ) + return res + @staticmethod def apply_previous_ignore_state( batches: list[ObjectDiffBatch], previously_ignored_batches: dict[UID, int] @@ -1414,65 +1409,65 @@ def hierarchies( def is_same(self) -> bool: return all(object_diff.status == "SAME" for object_diff in self.diffs) - def _apply_filters(self, filters: list[NodeDiffFilter]) -> Self: + def _apply_filters( + self, filters: list[NodeDiffFilter], inplace: bool = True + ) -> Self: """ Apply filters to the NodeDiff object and return a new NodeDiff object """ batches = self.all_batches for filter in filters: batches = [b for b in batches if filter(b)] - return NodeDiff( - low_node_uid=self.low_node_uid, - high_node_uid=self.high_node_uid, - user_verify_key_low=self.user_verify_key_low, - user_verify_key_high=self.user_verify_key_high, - obj_uid_to_diff=self.obj_uid_to_diff, - obj_dependencies=self.obj_dependencies, - batches=batches, - all_batches=self.all_batches, - low_state=self.low_state, - high_state=self.high_state, - direction=self.direction, - filters=filters, - ) - def reset_filters( + if inplace: + self.filters = 
filters + self.batches = batches + return self + else: + return NodeDiff( + low_node_uid=self.low_node_uid, + high_node_uid=self.high_node_uid, + user_verify_key_low=self.user_verify_key_low, + user_verify_key_high=self.user_verify_key_high, + obj_uid_to_diff=self.obj_uid_to_diff, + obj_dependencies=self.obj_dependencies, + batches=batches, + all_batches=self.all_batches, + low_state=self.low_state, + high_state=self.high_state, + direction=self.direction, + filters=filters, + ) + + def _filter( self, + user_email: str | None = None, + obj_type: str | type | None = None, include_ignored: bool = False, include_same: bool = False, + inplace: bool = True, ) -> Self: - filters = [] - if not include_ignored: - filters.append(NodeDiffFilter(FilterProperty.IGNORED, True, operator.ne)) - if not include_same: - filters.append(NodeDiffFilter(FilterProperty.STATUS, "SAME", operator.ne)) - return self._apply_filters(filters) - - def filter( - self, - user: str | None = None, - obj_type: type | None = None, - ) -> Self: - current_filters = self.filters new_filters = [] - if user is not None: - new_filters.append(NodeDiffFilter(FilterProperty.USER, user)) + if user_email is not None: + new_filters.append( + NodeDiffFilter(FilterProperty.USER, user_email, operator.eq) + ) if obj_type is not None: - new_filters.append(NodeDiffFilter(FilterProperty.TYPE, obj_type)) - - if len(new_filters) == 0: - return self - - new_filter_properties = {f.filter_property for f in new_filters} - # Only add filters that are not in the new filters - # - remove duplicate filters - # - overwrite filters with the same property but different value - # (example: cannot filter on 2 different users) - for current_filter in current_filters: - if current_filter.filter_property not in new_filter_properties: - new_filters.append(current_filter) + if isinstance(obj_type, type): + obj_type = obj_type.__name__ + new_filters.append( + NodeDiffFilter(FilterProperty.TYPE, obj_type, operator.eq) + ) + if not include_ignored: + new_filters.append( + NodeDiffFilter(FilterProperty.IGNORED, True, operator.ne) + ) + if not include_same: + new_filters.append( + NodeDiffFilter(FilterProperty.STATUS, "SAME", operator.ne) + ) - return self._apply_filters(new_filters) + return self._apply_filters(new_filters, inplace=inplace) class SyncInstruction(SyftObject): From eeb602c11da904d8c33df93199879725590907c8 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 23 May 2024 10:40:02 +0200 Subject: [PATCH 5/5] fix comment --- packages/syft/src/syft/service/sync/diff_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py index f40e86306c3..3cb360eafa4 100644 --- a/packages/syft/src/syft/service/sync/diff_state.py +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -1074,7 +1074,7 @@ class NodeDiffFilter: value: Any - value to compare against op: callable[[Any, Any], bool] - comparison operator. Default is `operator.eq` - If the comparison fails, the batch is included in the result. + If the comparison fails, the batch is excluded. """ filter_property: FilterProperty