From e1fc6e34fcea9d67037de00992626fbf1be7ca21 Mon Sep 17 00:00:00 2001
From: eelcovdw
Date: Wed, 22 May 2024 11:12:37 +0200
Subject: [PATCH] fix filter

---
 .../syft/src/syft/service/sync/diff_state.py | 175 ++++++++++++++++--
 1 file changed, 163 insertions(+), 12 deletions(-)

diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py
index 014e33f5bc8..82b1d6021a8 100644
--- a/packages/syft/src/syft/service/sync/diff_state.py
+++ b/packages/syft/src/syft/service/sync/diff_state.py
@@ -1,12 +1,17 @@
 # stdlib
+from collections.abc import Callable
 from collections.abc import Iterable
+from dataclasses import dataclass
+import enum
 import html
+import operator
 import textwrap
 from typing import Any
 from typing import ClassVar
 from typing import Literal
 
 # third party
+from loguru import logger
 import pandas as pd
 from pydantic import model_validator
 from rich import box
@@ -50,6 +55,7 @@
 from ..request.request import Request
 from ..response import SyftError
 from ..response import SyftSuccess
+from ..user.user import UserView
 from .sync_state import SyncState
 
 sketchy_tab = "‎ " * 4
@@ -944,19 +950,32 @@ def visual_root(self) -> ObjectDiff:
     @property
     def user_code_high(self) -> UserCode | None:
         """return the user code of the high side of this batch, if it exists"""
-        user_codes_high: list[UserCode] = [
-            diff.high_obj
+        user_code_diff = self.user_code_diff
+        if user_code_diff is not None and isinstance(user_code_diff.high_obj, UserCode):
+            return user_code_diff.high_obj
+        return None
+
+    @property
+    def user_code_diff(self) -> ObjectDiff | None:
+        """return the main user code diff of this batch, if it exists"""
+        user_code_diffs: list[ObjectDiff] = [
+            diff
             for diff in self.get_dependencies(include_roots=True)
-            if isinstance(diff.high_obj, UserCode)
+            if issubclass(diff.obj_type, UserCode)
         ]
 
-        if len(user_codes_high) == 0:
-            user_code_high = None
+        if len(user_code_diffs) == 0:
+            return None
         else:
-            # NOTE we can always assume the first usercode is
-            # not a nested code, because diffs are sorted in depth-first order
-            user_code_high = user_codes_high[0]
-        return user_code_high
+            # main usercode is always the first, batches are sorted in depth-first order
+            return user_code_diffs[0]
+
+    @property
+    def user(self) -> UserView | SyftError:
+        user_code_diff = self.user_code_diff
+        if user_code_diff is not None and isinstance(user_code_diff.low_obj, UserCode):
+            return user_code_diff.low_obj.user
+        return SyftError(message="No user found")
 
     def get_visual_hierarchy(self) -> dict[ObjectDiff, dict]:
         visual_root = self.visual_root
@@ -1022,6 +1041,70 @@ def stage_change(self) -> None:
                 other_batch.decision = None
 
 
+class FilterProperty(enum.Enum):
+    USER = enum.auto()
+    TYPE = enum.auto()
+    STATUS = enum.auto()
+    IGNORED = enum.auto()
+
+    def from_batch(self, batch: ObjectDiffBatch) -> Any:
+        if self == FilterProperty.USER:
+            user = batch.user
+            if isinstance(user, UserView):
+                return user.email
+            return None
+        elif self == FilterProperty.TYPE:
+            return batch.root_diff.obj_type
+        elif self == FilterProperty.STATUS:
+            return batch.status
+        elif self == FilterProperty.IGNORED:
+            return batch.is_ignored
+        else:
+            raise ValueError(f"Invalid property: {self}")
+
+
+@dataclass
+class NodeDiffFilter:
+    """
+    Filter applied to an ObjectDiffBatch to determine if it should be included in a NodeDiff.
+
+    Tests for `property op value`, where
+    property: FilterProperty - property to filter on
+    value: Any - value to compare against
+    op: Callable[[Any, Any], bool] - comparison operator. Default is `operator.eq`
+
+    If evaluating the comparison raises an exception, the batch is included in the result.
+    """
+
+    filter_property: FilterProperty
+    filter_value: Any
+    op: Callable[[Any, Any], bool] = operator.eq
+
+    def __call__(self, batch: ObjectDiffBatch) -> bool:
+        try:
+            p = self.filter_property.from_batch(batch)
+            if self.op == operator.contains:
+                # Contains check has reversed arg order: check if p in self.filter_value
+                return p in self.filter_value
+            else:
+                return self.op(p, self.filter_value)
+        except Exception as e:
+            logger.debug(f"Error filtering batch {batch} with {self}: {e}")
+            return True
+
+    def __hash__(self) -> int:
+        return hash(self.filter_property) + hash(self.filter_value) + hash(self.op)
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, NodeDiffFilter):
+            return False
+        return (
+            self.filter_property == other.filter_property
+            and self.filter_value == other.filter_value
+            and self.op == other.op
+        )
+
+
 class NodeDiff(SyftObject):
     __canonical_name__ = "NodeDiff"
     __version__ = SYFT_OBJECT_VERSION_2
@@ -1037,6 +1120,7 @@ class NodeDiff(SyftObject):
     low_state: SyncState
     high_state: SyncState
     direction: SyncDirection | None
+    filters: list[NodeDiffFilter] = []
 
     include_ignored: bool = False
 
@@ -1074,6 +1158,7 @@ def from_sync_state(
         high_state: SyncState,
         direction: SyncDirection,
         include_ignored: bool = False,
+        include_same: bool = False,
         _include_node_status: bool = False,
     ) -> "NodeDiff":
         obj_uid_to_diff = {}
@@ -1126,10 +1211,15 @@ def from_sync_state(
         previously_ignored_batches = low_state.ignored_batches
         NodeDiff.apply_previous_ignore_state(all_batches, previously_ignored_batches)
 
+        filters = []
         if not include_ignored:
-            batches = [b for b in all_batches if not b.is_ignored]
-        else:
-            batches = all_batches
+            filters.append(NodeDiffFilter(FilterProperty.IGNORED, True, operator.ne))
+        if not include_same:
+            filters.append(NodeDiffFilter(FilterProperty.STATUS, "SAME", operator.ne))
+
+        batches = all_batches
+        for f in filters:
+            batches = [b for b in batches if f(b)]
 
         return cls(
             low_node_uid=low_state.node_uid,
@@ -1143,6 +1233,7 @@ def from_sync_state(
             low_state=low_state,
             high_state=high_state,
             direction=direction,
+            filters=filters,
         )
 
     @staticmethod
@@ -1317,6 +1408,66 @@ def hierarchies(
     def is_same(self) -> bool:
         return all(object_diff.status == "SAME" for object_diff in self.diffs)
 
+    def _apply_filters(self, filters: list[NodeDiffFilter]) -> Self:
+        """
+        Apply filters to the NodeDiff object and return a new NodeDiff object
+        """
+        batches = self.all_batches
+        for f in filters:
+            batches = [b for b in batches if f(b)]
+        return NodeDiff(
+            low_node_uid=self.low_node_uid,
+            high_node_uid=self.high_node_uid,
+            user_verify_key_low=self.user_verify_key_low,
+            user_verify_key_high=self.user_verify_key_high,
+            obj_uid_to_diff=self.obj_uid_to_diff,
+            obj_dependencies=self.obj_dependencies,
+            batches=batches,
+            all_batches=self.all_batches,
+            low_state=self.low_state,
+            high_state=self.high_state,
+            direction=self.direction,
+            filters=filters,
+        )
+
+    def reset_filters(
+        self,
+        include_ignored: bool = False,
+        include_same: bool = False,
+    ) -> Self:
+        filters = []
+        if not include_ignored:
+            filters.append(NodeDiffFilter(FilterProperty.IGNORED, True, operator.ne))
+        if not include_same:
+            filters.append(NodeDiffFilter(FilterProperty.STATUS, "SAME", operator.ne))
+        return self._apply_filters(filters)
+
+    def filter(
+        self,
+        user: str | None = None,
+        obj_type: type | None = None,
+    ) -> Self:
+        current_filters = self.filters
+        new_filters = []
+        if user is not None:
+            new_filters.append(NodeDiffFilter(FilterProperty.USER, user))
+        if obj_type is not None:
+            new_filters.append(NodeDiffFilter(FilterProperty.TYPE, obj_type))
+
+        if len(new_filters) == 0:
+            return self
+
+        new_filter_properties = {f.filter_property for f in new_filters}
+        # Only add filters that are not in the new filters
+        # - remove duplicate filters
+        # - overwrite filters with the same property but different value
+        #   (example: cannot filter on 2 different users)
+        for current_filter in current_filters:
+            if current_filter.filter_property not in new_filter_properties:
+                new_filters.append(current_filter)
+
+        return self._apply_filters(new_filters)
+
 
 class SyncInstruction(SyftObject):
     __canonical_name__ = "SyncDecision"