diff --git a/notebooks/api/0.8/05-custom-policy.ipynb b/notebooks/api/0.8/05-custom-policy.ipynb index 75b8f43e5e7..fac27bfcbe8 100644 --- a/notebooks/api/0.8/05-custom-policy.ipynb +++ b/notebooks/api/0.8/05-custom-policy.ipynb @@ -238,13 +238,207 @@ "cell_type": "code", "execution_count": null, "id": "16", + "metadata": {}, + "outputs": [], + "source": [ + "# third party\n", + "from result import Err\n", + "from result import Ok\n", + "\n", + "# syft absolute\n", + "from syft.client.api import AuthedServiceContext\n", + "from syft.client.api import NodeIdentity\n", + "\n", + "\n", + "class CustomExactMatch(sy.CustomInputPolicy):\n", + " def __init__(self, *args: Any, **kwargs: Any) -> None:\n", + " pass\n", + "\n", + " def filter_kwargs(self, kwargs, context, code_item_id):\n", + " # stdlib\n", + "\n", + " try:\n", + " allowed_inputs = self.allowed_ids_only(\n", + " allowed_inputs=self.inputs, kwargs=kwargs, context=context\n", + " )\n", + " results = self.retrieve_from_db(\n", + " code_item_id=code_item_id,\n", + " allowed_inputs=allowed_inputs,\n", + " context=context,\n", + " )\n", + " except Exception as e:\n", + " return Err(str(e))\n", + " return results\n", + "\n", + " def retrieve_from_db(self, code_item_id, allowed_inputs, context):\n", + " # syft absolute\n", + " from syft import NodeType\n", + " from syft.service.action.action_object import TwinMode\n", + "\n", + " action_service = context.node.get_service(\"actionservice\")\n", + " code_inputs = {}\n", + "\n", + " # When we are retrieving the code from the database, we need to use the node's\n", + " # verify key as the credentials. This is because when we approve the code, we\n", + " # we allow the private data to be used only for this specific code.\n", + " # but we are not modifying the permissions of the private data\n", + "\n", + " root_context = AuthedServiceContext(\n", + " node=context.node, credentials=context.node.verify_key\n", + " )\n", + " if context.node.node_type == NodeType.DOMAIN:\n", + " for var_name, arg_id in allowed_inputs.items():\n", + " kwarg_value = action_service._get(\n", + " context=root_context,\n", + " uid=arg_id,\n", + " twin_mode=TwinMode.NONE,\n", + " has_permission=True,\n", + " )\n", + " if kwarg_value.is_err():\n", + " return Err(kwarg_value.err())\n", + " code_inputs[var_name] = kwarg_value.ok()\n", + "\n", + " elif context.node.node_type == NodeType.ENCLAVE:\n", + " dict_object = action_service.get(context=root_context, uid=code_item_id)\n", + " if dict_object.is_err():\n", + " return Err(dict_object.err())\n", + " for value in dict_object.ok().syft_action_data.values():\n", + " code_inputs.update(value)\n", + "\n", + " else:\n", + " raise Exception(\n", + " f\"Invalid Node Type for Code Submission:{context.node.node_type}\"\n", + " )\n", + " return Ok(code_inputs)\n", + "\n", + " def allowed_ids_only(\n", + " self,\n", + " allowed_inputs,\n", + " kwargs,\n", + " context,\n", + " ):\n", + " # syft absolute\n", + " from syft import NodeType\n", + " from syft import UID\n", + "\n", + " if context.node.node_type == NodeType.DOMAIN:\n", + " node_identity = NodeIdentity(\n", + " node_name=context.node.name,\n", + " node_id=context.node.id,\n", + " verify_key=context.node.signing_key.verify_key,\n", + " )\n", + " allowed_inputs = allowed_inputs.get(node_identity, {})\n", + " elif context.node.node_type == NodeType.ENCLAVE:\n", + " base_dict = {}\n", + " for key in allowed_inputs.values():\n", + " base_dict.update(key)\n", + " allowed_inputs = base_dict\n", + " else:\n", + " raise Exception(\n", + " f\"Invalid Node Type for Code Submission:{context.node.node_type}\"\n", + " )\n", + " filtered_kwargs = {}\n", + " for key in allowed_inputs.keys():\n", + " if key in kwargs:\n", + " value = kwargs[key]\n", + " uid = value\n", + " if not isinstance(uid, UID):\n", + " uid = getattr(value, \"id\", None)\n", + "\n", + " if uid != allowed_inputs[key]:\n", + " raise Exception(\n", + " f\"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}\"\n", + " )\n", + " filtered_kwargs[key] = value\n", + " return filtered_kwargs\n", + "\n", + " def _is_valid(\n", + " self,\n", + " context,\n", + " usr_input_kwargs,\n", + " code_item_id,\n", + " ):\n", + " filtered_input_kwargs = self.filter_kwargs(\n", + " kwargs=usr_input_kwargs,\n", + " context=context,\n", + " code_item_id=code_item_id,\n", + " )\n", + "\n", + " if filtered_input_kwargs.is_err():\n", + " return filtered_input_kwargs\n", + "\n", + " filtered_input_kwargs = filtered_input_kwargs.ok()\n", + "\n", + " expected_input_kwargs = set()\n", + " for _inp_kwargs in self.inputs.values():\n", + " for k in _inp_kwargs.keys():\n", + " if k not in usr_input_kwargs:\n", + " return Err(f\"Function missing required keyword argument: '{k}'\")\n", + " expected_input_kwargs.update(_inp_kwargs.keys())\n", + "\n", + " permitted_input_kwargs = list(filtered_input_kwargs.keys())\n", + " not_approved_kwargs = set(expected_input_kwargs) - set(permitted_input_kwargs)\n", + " if len(not_approved_kwargs) > 0:\n", + " return Err(\n", + " f\"Input arguments: {not_approved_kwargs} to the function are not approved yet.\"\n", + " )\n", + " return Ok(True)\n", + "\n", + "\n", + "def allowed_ids_only(\n", + " self,\n", + " allowed_inputs,\n", + " kwargs,\n", + " context,\n", + "):\n", + " # syft absolute\n", + " from syft import NodeType\n", + " from syft import UID\n", + " from syft.client.api import NodeIdentity\n", + "\n", + " if context.node.node_type == NodeType.DOMAIN:\n", + " node_identity = NodeIdentity(\n", + " node_name=context.node.name,\n", + " node_id=context.node.id,\n", + " verify_key=context.node.signing_key.verify_key,\n", + " )\n", + " allowed_inputs = allowed_inputs.get(node_identity, {})\n", + " elif context.node.node_type == NodeType.ENCLAVE:\n", + " base_dict = {}\n", + " for key in allowed_inputs.values():\n", + " base_dict.update(key)\n", + " allowed_inputs = base_dict\n", + " else:\n", + " raise Exception(\n", + " f\"Invalid Node Type for Code Submission:{context.node.node_type}\"\n", + " )\n", + " filtered_kwargs = {}\n", + " for key in allowed_inputs.keys():\n", + " if key in kwargs:\n", + " value = kwargs[key]\n", + " uid = value\n", + " if not isinstance(uid, UID):\n", + " uid = getattr(value, \"id\", None)\n", + "\n", + " if uid != allowed_inputs[key]:\n", + " raise Exception(\n", + " f\"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}\"\n", + " )\n", + " filtered_kwargs[key] = value\n", + " return filtered_kwargs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "17", "metadata": { "tags": [] }, "outputs": [], "source": [ "@sy.syft_function(\n", - " input_policy=sy.ExactMatch(x=x_pointer),\n", + " input_policy=CustomExactMatch(x=x_pointer),\n", " output_policy=RepeatedCallPolicy(n_calls=10, downloadable_output_args=[\"y\"]),\n", ")\n", "def func(x):\n", @@ -254,7 +448,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "18", "metadata": { "tags": [] }, @@ -267,21 +461,13 @@ { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "19", "metadata": {}, "outputs": [], "source": [ "request_id = request.id" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "19", - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": null, @@ -433,14 +619,6 @@ "if node.node_type.value == \"python\":\n", " node.land()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "32", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -459,7 +637,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": "3.11.0rc1" }, "toc": { "base_numbering": 1, diff --git a/packages/syft/src/syft/client/syncing.py b/packages/syft/src/syft/client/syncing.py index 371b77df22a..7185cb5316e 100644 --- a/packages/syft/src/syft/client/syncing.py +++ b/packages/syft/src/syft/client/syncing.py @@ -18,7 +18,12 @@ def compare_states( - from_state: SyncState, to_state: SyncState, include_ignored: bool = False + from_state: SyncState, + to_state: SyncState, + include_ignored: bool = False, + include_same: bool = False, + filter_by_email: str | None = None, + filter_by_type: str | type | None = None, ) -> NodeDiff: # NodeDiff if ( @@ -42,11 +47,28 @@ def compare_states( high_state=high_state, direction=direction, include_ignored=include_ignored, + include_same=include_same, + filter_by_email=filter_by_email, + filter_by_type=filter_by_type, ) -def compare_clients(low_client: SyftClient, high_client: SyftClient) -> NodeDiff: - return compare_states(low_client.get_sync_state(), high_client.get_sync_state()) +def compare_clients( + from_client: SyftClient, + to_client: SyftClient, + include_ignored: bool = False, + include_same: bool = False, + filter_by_email: str | None = None, + filter_by_type: type | None = None, +) -> NodeDiff: + return compare_states( + from_client.get_sync_state(), + to_client.get_sync_state(), + include_ignored=include_ignored, + include_same=include_same, + filter_by_email=filter_by_email, + filter_by_type=filter_by_type, + ) def get_user_input_for_resolve() -> SyncDecision: diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index 604e21611f4..2cbeaf31967 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -76,6 +76,7 @@ from ..policy.policy import filter_only_uids from ..policy.policy import init_policy from ..policy.policy import load_policy_code +from ..policy.policy import partition_by_node from ..policy.policy_service import PolicyService from ..response import SyftError from ..response import SyftInfo @@ -973,10 +974,13 @@ def syft_function( if input_policy is None: input_policy = EmpyInputPolicy() + init_input_kwargs = None if isinstance(input_policy, CustomInputPolicy): input_policy_type = SubmitUserPolicy.from_obj(input_policy) + init_input_kwargs = partition_by_node(input_policy.init_kwargs) else: input_policy_type = type(input_policy) + init_input_kwargs = getattr(input_policy, "init_kwargs", {}) if output_policy is None: output_policy = SingleExecutionExactOutput() @@ -992,7 +996,7 @@ def decorator(f: Any) -> SubmitUserCode: func_name=f.__name__, signature=inspect.signature(f), input_policy_type=input_policy_type, - input_policy_init_kwargs=getattr(input_policy, "init_kwargs", {}), + input_policy_init_kwargs=init_input_kwargs, output_policy_type=output_policy_type, output_policy_init_kwargs=getattr(output_policy, "init_kwargs", {}), local_function=f, diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py index 04f49bac453..c5fda8e0f3e 100644 --- a/packages/syft/src/syft/service/policy/policy.py +++ b/packages/syft/src/syft/service/policy/policy.py @@ -854,6 +854,26 @@ def load_policy_code(user_policy: UserPolicy) -> Any: def init_policy(user_policy: UserPolicy, init_args: dict[str, Any]) -> Any: policy_class = load_policy_code(user_policy) policy_object = policy_class() + + # Unwrapp {NodeIdentity : {x: y}} -> {x: y} + # Tech debt : For input policies, we required to have NodeIdentity args beforehand, + # therefore at this stage we had to return back to the normal args. + # Maybe there's better way to do it. + if len(init_args) and isinstance(list(init_args.keys())[0], NodeIdentity): + unwrapped_init_kwargs = init_args + if len(init_args) > 1: + raise Exception("You shoudn't have more than one Node Identity.") + # Otherwise, unwrapp it + init_args = init_args[list(init_args.keys())[0]] + init_args = {k: v for k, v in init_args.items() if k != "id"} + + # For input policies, this initializer wouldn't work properly: + # 1 - Passing {NodeIdentity: {kwargs:UIDs}} as keyword args doesn't work since keys must be strings + # 2 - Passing {kwargs: UIDs} in this initializer would not trigger the partition nodes from the + # InputPolicy initializer. + # The cleanest way to solve it is by checking if it's an Input Policy, and then, setting it manually. policy_object.__user_init__(**init_args) + if isinstance(policy_object, InputPolicy): + policy_object.init_kwargs = unwrapped_init_kwargs return policy_object diff --git a/packages/syft/src/syft/service/sync/diff_state.py b/packages/syft/src/syft/service/sync/diff_state.py index cde79262c24..3cb360eafa4 100644 --- a/packages/syft/src/syft/service/sync/diff_state.py +++ b/packages/syft/src/syft/service/sync/diff_state.py @@ -1,12 +1,17 @@ # stdlib +from collections.abc import Callable from collections.abc import Iterable +from dataclasses import dataclass +import enum import html +import operator import textwrap from typing import Any from typing import ClassVar from typing import Literal # third party +from loguru import logger import pandas as pd from pydantic import model_validator from rich import box @@ -51,6 +56,7 @@ from ..request.request import Request from ..response import SyftError from ..response import SyftSuccess +from ..user.user import UserView from .sync_state import SyncState sketchy_tab = "‎ " * 4 @@ -945,19 +951,32 @@ def visual_root(self) -> ObjectDiff: @property def user_code_high(self) -> UserCode | None: """return the user code of the high side of this batch, if it exists""" - user_codes_high: list[UserCode] = [ - diff.high_obj + user_code_diff = self.user_code_diff + if user_code_diff is not None and isinstance(user_code_diff.high_obj, UserCode): + return user_code_diff.high_obj + return None + + @property + def user_code_diff(self) -> ObjectDiff | None: + """return the main user code diff of the high side of this batch, if it exists""" + user_code_diffs: list[ObjectDiff] = [ + diff for diff in self.get_dependencies(include_roots=True) - if isinstance(diff.high_obj, UserCode) + if issubclass(diff.obj_type, UserCode) ] - if len(user_codes_high) == 0: - user_code_high = None + if len(user_code_diffs) == 0: + return None else: - # NOTE we can always assume the first usercode is - # not a nested code, because diffs are sorted in depth-first order - user_code_high = user_codes_high[0] - return user_code_high + # main usercode is always the first, batches are sorted in depth-first order + return user_code_diffs[0] + + @property + def user(self) -> UserView | SyftError: + user_code_diff = self.user_code_diff + if user_code_diff is not None and isinstance(user_code_diff.low_obj, UserCode): + return user_code_diff.low_obj.user + return SyftError(message="No user found") def get_visual_hierarchy(self) -> dict[ObjectDiff, dict]: visual_root = self.visual_root @@ -1023,6 +1042,63 @@ def stage_change(self) -> None: other_batch.decision = None +class FilterProperty(enum.Enum): + USER = enum.auto() + TYPE = enum.auto() + STATUS = enum.auto() + IGNORED = enum.auto() + + def from_batch(self, batch: ObjectDiffBatch) -> Any: + if self == FilterProperty.USER: + user = batch.user + if isinstance(user, UserView): + return user.email + return None + elif self == FilterProperty.TYPE: + return batch.root_diff.obj_type.__name__.lower() + elif self == FilterProperty.STATUS: + return batch.status.lower() + elif self == FilterProperty.IGNORED: + return batch.is_ignored + else: + raise ValueError(f"Invalid property: {property}") + + +@dataclass +class NodeDiffFilter: + """ + Filter to apply to a NodeDiff object to determine if it should be included in a batch. + + Checks for `property op value` , where + property: FilterProperty - property to filter on + value: Any - value to compare against + op: callable[[Any, Any], bool] - comparison operator. Default is `operator.eq` + + If the comparison fails, the batch is excluded. + """ + + filter_property: FilterProperty + filter_value: Any + op: Callable[[Any, Any], bool] = operator.eq + + def __call__(self, batch: ObjectDiffBatch) -> bool: + filter_value = self.filter_value + if isinstance(filter_value, str): + filter_value = filter_value.lower() + + try: + p = self.filter_property.from_batch(batch) + if self.op == operator.contains: + # Contains check has reversed arg order: check if p in self.filter_value + return p in filter_value + else: + return self.op(p, filter_value) + except Exception as e: + # By default, exclude the batch if there is an error + logger.debug(f"Error filtering batch {batch} with {self}: {e}") + return False + + class NodeDiff(SyftObject): __canonical_name__ = "NodeDiff" __version__ = SYFT_OBJECT_VERSION_2 @@ -1038,6 +1114,7 @@ class NodeDiff(SyftObject): low_state: SyncState high_state: SyncState direction: SyncDirection | None + filters: list[NodeDiffFilter] = [] include_ignored: bool = False @@ -1075,6 +1152,9 @@ def from_sync_state( high_state: SyncState, direction: SyncDirection, include_ignored: bool = False, + include_same: bool = False, + filter_by_email: str | None = None, + filter_by_type: type | None = None, _include_node_status: bool = False, ) -> "NodeDiff": obj_uid_to_diff = {} @@ -1127,25 +1207,31 @@ def from_sync_state( previously_ignored_batches = low_state.ignored_batches NodeDiff.apply_previous_ignore_state(all_batches, previously_ignored_batches) - if not include_ignored: - batches = [b for b in all_batches if not b.is_ignored] - else: - batches = all_batches - - return cls( + res = cls( low_node_uid=low_state.node_uid, high_node_uid=high_state.node_uid, user_verify_key_low=low_state.syft_client_verify_key, user_verify_key_high=high_state.syft_client_verify_key, obj_uid_to_diff=obj_uid_to_diff, obj_dependencies=obj_dependencies, - batches=batches, + batches=all_batches, all_batches=all_batches, low_state=low_state, high_state=high_state, direction=direction, + filters=[], ) + res._filter( + user_email=filter_by_email, + obj_type=filter_by_type, + include_ignored=include_ignored, + include_same=include_same, + inplace=True, + ) + + return res + @staticmethod def apply_previous_ignore_state( batches: list[ObjectDiffBatch], previously_ignored_batches: dict[UID, int] @@ -1323,6 +1409,66 @@ def hierarchies( def is_same(self) -> bool: return all(object_diff.status == "SAME" for object_diff in self.diffs) + def _apply_filters( + self, filters: list[NodeDiffFilter], inplace: bool = True + ) -> Self: + """ + Apply filters to the NodeDiff object and return a new NodeDiff object + """ + batches = self.all_batches + for filter in filters: + batches = [b for b in batches if filter(b)] + + if inplace: + self.filters = filters + self.batches = batches + return self + else: + return NodeDiff( + low_node_uid=self.low_node_uid, + high_node_uid=self.high_node_uid, + user_verify_key_low=self.user_verify_key_low, + user_verify_key_high=self.user_verify_key_high, + obj_uid_to_diff=self.obj_uid_to_diff, + obj_dependencies=self.obj_dependencies, + batches=batches, + all_batches=self.all_batches, + low_state=self.low_state, + high_state=self.high_state, + direction=self.direction, + filters=filters, + ) + + def _filter( + self, + user_email: str | None = None, + obj_type: str | type | None = None, + include_ignored: bool = False, + include_same: bool = False, + inplace: bool = True, + ) -> Self: + new_filters = [] + if user_email is not None: + new_filters.append( + NodeDiffFilter(FilterProperty.USER, user_email, operator.eq) + ) + if obj_type is not None: + if isinstance(obj_type, type): + obj_type = obj_type.__name__ + new_filters.append( + NodeDiffFilter(FilterProperty.TYPE, obj_type, operator.eq) + ) + if not include_ignored: + new_filters.append( + NodeDiffFilter(FilterProperty.IGNORED, True, operator.ne) + ) + if not include_same: + new_filters.append( + NodeDiffFilter(FilterProperty.STATUS, "SAME", operator.ne) + ) + + return self._apply_filters(new_filters, inplace=inplace) + class SyncInstruction(SyftObject): __canonical_name__ = "SyncDecision"