Skip to content

Commit

Permalink
fix filter
Browse files Browse the repository at this point in the history
  • Loading branch information
eelcovdw committed May 22, 2024
1 parent 5c739a5 commit e1fc6e3
Showing 1 changed file with 163 additions and 12 deletions.
175 changes: 163 additions & 12 deletions packages/syft/src/syft/service/sync/diff_state.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
# stdlib
from collections.abc import Callable
from collections.abc import Iterable
from dataclasses import dataclass
import enum
import html
import operator
import textwrap
from typing import Any
from typing import ClassVar
from typing import Literal

# third party
from loguru import logger
import pandas as pd
from pydantic import model_validator
from rich import box
Expand Down Expand Up @@ -50,6 +55,7 @@
from ..request.request import Request
from ..response import SyftError
from ..response import SyftSuccess
from ..user.user import UserView
from .sync_state import SyncState

sketchy_tab = "‎ " * 4
Expand Down Expand Up @@ -944,19 +950,32 @@ def visual_root(self) -> ObjectDiff:
@property
def user_code_high(self) -> UserCode | None:
"""return the user code of the high side of this batch, if it exists"""
user_codes_high: list[UserCode] = [
diff.high_obj
user_code_diff = self.user_code_diff
if user_code_diff is not None and isinstance(user_code_diff.high_obj, UserCode):
return user_code_diff.high_obj
return None

@property
def user_code_diff(self) -> ObjectDiff | None:
"""return the main user code diff of the high side of this batch, if it exists"""
user_code_diffs: list[ObjectDiff] = [
diff
for diff in self.get_dependencies(include_roots=True)
if isinstance(diff.high_obj, UserCode)
if issubclass(diff.obj_type, UserCode)
]

if len(user_codes_high) == 0:
user_code_high = None
if len(user_code_diffs) == 0:
return None
else:
# NOTE we can always assume the first usercode is
# not a nested code, because diffs are sorted in depth-first order
user_code_high = user_codes_high[0]
return user_code_high
# main usercode is always the first, batches are sorted in depth-first order
return user_code_diffs[0]

@property
def user(self) -> UserView | SyftError:
    """Return the `user` attribute of the low-side user code of this batch.

    Looks up `self.user_code_diff`; when that diff exists and its `low_obj`
    is a `UserCode`, the `user` attribute of that object is returned.
    In every other case a `SyftError` with message "No user found" is returned.
    """
    code_diff = self.user_code_diff
    if code_diff is None:
        # no user code diff for this batch at all
        return SyftError(message="No user found")

    low_code = code_diff.low_obj
    if not isinstance(low_code, UserCode):
        # low side is missing or is not a UserCode
        return SyftError(message="No user found")

    return low_code.user

def get_visual_hierarchy(self) -> dict[ObjectDiff, dict]:
visual_root = self.visual_root
Expand Down Expand Up @@ -1022,6 +1041,70 @@ def stage_change(self) -> None:
other_batch.decision = None


class FilterProperty(enum.Enum):
    """Properties of an `ObjectDiffBatch` that a filter can compare against.

    Each member names one value that `from_batch` knows how to extract
    from a batch.
    """

    USER = enum.auto()
    TYPE = enum.auto()
    STATUS = enum.auto()
    IGNORED = enum.auto()

    def from_batch(self, batch: "ObjectDiffBatch") -> Any:
        """Extract the value of this property from `batch`.

        Returns:
            - USER: `batch.user.email` when `batch.user` is a `UserView`, else None
            - TYPE: `batch.root_diff.obj_type`
            - STATUS: `batch.status`
            - IGNORED: `batch.is_ignored`

        Raises:
            ValueError: if this member has no extraction rule.
        """
        if self is FilterProperty.USER:
            user = batch.user
            # batch.user may be a SyftError when no user could be resolved
            if isinstance(user, UserView):
                return user.email
            return None
        elif self is FilterProperty.TYPE:
            # NOTE: the enum member is TYPE; comparing against a non-existent
            # member here would raise AttributeError for every non-USER property.
            return batch.root_diff.obj_type
        elif self is FilterProperty.STATUS:
            return batch.status
        elif self is FilterProperty.IGNORED:
            return batch.is_ignored
        else:
            raise ValueError(f"Invalid property: {self}")


@dataclass
class NodeDiffFilter:
    """
    Predicate over an ObjectDiffBatch: calling the filter with a batch returns
    whether `property op value` holds for that batch, where
        filter_property: FilterProperty - which value to extract from the batch
        filter_value: Any - value to compare the extracted value against
        op: Callable[[Any, Any], bool] - comparison, called as
            `op(extracted, filter_value)`. Default is `operator.eq`
    Special case: for `operator.contains` the test is `extracted in filter_value`.
    If evaluating the test raises, the error is logged at debug level and the
    filter returns True (the batch is kept).
    """

    filter_property: FilterProperty
    filter_value: Any
    op: Callable[[Any, Any], bool] = operator.eq

    def __call__(self, batch: ObjectDiffBatch) -> bool:
        """Evaluate this filter against `batch`; True means the batch passes."""
        try:
            extracted = self.filter_property.from_batch(batch)
            if self.op == operator.contains:
                # operator.contains(a, b) tests `b in a`; here the container is
                # the filter value, so spell the membership test out directly.
                return extracted in self.filter_value
            return self.op(extracted, self.filter_value)
        except Exception as e:
            # best effort: a filter that cannot be evaluated never drops a batch
            logger.debug(f"Error filtering batch {batch} with {self}: {e}")
            return True

    def __hash__(self) -> int:
        """Hash built from the property, the value and the operator."""
        parts = (self.filter_property, self.filter_value, self.op)
        return sum(hash(part) for part in parts)

    def __eq__(self, other: Any) -> bool:
        """Two filters are equal when property, value and operator all match."""
        if isinstance(other, NodeDiffFilter):
            same_property = self.filter_property == other.filter_property
            return (
                same_property
                and self.filter_value == other.filter_value
                and self.op == other.op
            )
        return False


class NodeDiff(SyftObject):
__canonical_name__ = "NodeDiff"
__version__ = SYFT_OBJECT_VERSION_2
Expand All @@ -1037,6 +1120,7 @@ class NodeDiff(SyftObject):
low_state: SyncState
high_state: SyncState
direction: SyncDirection | None
filters: list[NodeDiffFilter] = []

include_ignored: bool = False

Expand Down Expand Up @@ -1074,6 +1158,7 @@ def from_sync_state(
high_state: SyncState,
direction: SyncDirection,
include_ignored: bool = False,
include_same: bool = False,
_include_node_status: bool = False,
) -> "NodeDiff":
obj_uid_to_diff = {}
Expand Down Expand Up @@ -1126,10 +1211,15 @@ def from_sync_state(
previously_ignored_batches = low_state.ignored_batches
NodeDiff.apply_previous_ignore_state(all_batches, previously_ignored_batches)

filters = []
if not include_ignored:
batches = [b for b in all_batches if not b.is_ignored]
else:
batches = all_batches
filters.append(NodeDiffFilter(FilterProperty.IGNORED, True, operator.ne))
if not include_same:
filters.append(NodeDiffFilter(FilterProperty.STATUS, "SAME", operator.ne))

batches = all_batches
for f in filters:
batches = [b for b in batches if f(b)]

return cls(
low_node_uid=low_state.node_uid,
Expand All @@ -1143,6 +1233,7 @@ def from_sync_state(
low_state=low_state,
high_state=high_state,
direction=direction,
filters=filters,
)

@staticmethod
Expand Down Expand Up @@ -1317,6 +1408,66 @@ def hierarchies(
def is_same(self) -> bool:
    """Return True when every diff in `self.diffs` has status "SAME".

    An empty `self.diffs` yields True.
    """
    for object_diff in self.diffs:
        if not object_diff.status == "SAME":
            # one differing object is enough to decide
            return False
    return True

def _apply_filters(self, filters: list[NodeDiffFilter]) -> Self:
    """
    Return a new NodeDiff whose `batches` are the entries of `self.all_batches`
    accepted by every filter in `filters`.

    Filters are applied one after another, in the order given; a batch is kept
    only if each filter returns a truthy value for it. All other fields are
    copied from this object, and `filters` is stored on the result.
    """
    # Narrow the complete batch list step by step, one filter at a time.
    remaining = self.all_batches
    for batch_filter in filters:
        remaining = [batch for batch in remaining if batch_filter(batch)]

    # Everything except `batches` and `filters` is carried over unchanged.
    carried_over = {
        "low_node_uid": self.low_node_uid,
        "high_node_uid": self.high_node_uid,
        "user_verify_key_low": self.user_verify_key_low,
        "user_verify_key_high": self.user_verify_key_high,
        "obj_uid_to_diff": self.obj_uid_to_diff,
        "obj_dependencies": self.obj_dependencies,
        "all_batches": self.all_batches,
        "low_state": self.low_state,
        "high_state": self.high_state,
        "direction": self.direction,
    }
    return NodeDiff(batches=remaining, filters=filters, **carried_over)

def reset_filters(
self,
include_ignored: bool = False,
include_same: bool = False,
) -> Self:
filters = []
if not include_ignored:
filters.append(NodeDiffFilter(FilterProperty.IGNORED, True, operator.ne))
if not include_same:
filters.append(NodeDiffFilter(FilterProperty.STATUS, "SAME", operator.ne))
return self._apply_filters(filters)

def filter(
self,
user: str | None = None,
obj_type: type | None = None,
) -> Self:
current_filters = self.filters
new_filters = []
if user is not None:
new_filters.append(NodeDiffFilter(FilterProperty.USER, user))
if obj_type is not None:
new_filters.append(NodeDiffFilter(FilterProperty.TYPE, obj_type))

if len(new_filters) == 0:
return self

new_filter_properties = {f.filter_property for f in new_filters}
# Only add filters that are not in the new filters
# - remove duplicate filters
# - overwrite filters with the same property but different value
# (example: cannot filter on 2 different users)
for current_filter in current_filters:
if current_filter.filter_property not in new_filter_properties:
new_filters.append(current_filter)

return self._apply_filters(new_filters)


class SyncInstruction(SyftObject):
__canonical_name__ = "SyncDecision"
Expand Down

0 comments on commit e1fc6e3

Please sign in to comment.