Skip to content

Commit

Permalink
Merge branch 'dev' into remove-jax-haiku
Browse files Browse the repository at this point in the history
  • Loading branch information
shubham3121 authored May 24, 2024
2 parents 55487fc + cb6dfe6 commit f41c592
Show file tree
Hide file tree
Showing 5 changed files with 410 additions and 40 deletions.
218 changes: 198 additions & 20 deletions notebooks/api/0.8/05-custom-policy.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -238,13 +238,207 @@
"cell_type": "code",
"execution_count": null,
"id": "16",
"metadata": {},
"outputs": [],
"source": [
"# third party\n",
"from result import Err\n",
"from result import Ok\n",
"\n",
"# syft absolute\n",
"from syft.client.api import AuthedServiceContext\n",
"from syft.client.api import NodeIdentity\n",
"\n",
"\n",
"class CustomExactMatch(sy.CustomInputPolicy):\n",
" def __init__(self, *args: Any, **kwargs: Any) -> None:\n",
" pass\n",
"\n",
" def filter_kwargs(self, kwargs, context, code_item_id):\n",
" # stdlib\n",
"\n",
" try:\n",
" allowed_inputs = self.allowed_ids_only(\n",
" allowed_inputs=self.inputs, kwargs=kwargs, context=context\n",
" )\n",
" results = self.retrieve_from_db(\n",
" code_item_id=code_item_id,\n",
" allowed_inputs=allowed_inputs,\n",
" context=context,\n",
" )\n",
" except Exception as e:\n",
" return Err(str(e))\n",
" return results\n",
"\n",
" def retrieve_from_db(self, code_item_id, allowed_inputs, context):\n",
" # syft absolute\n",
" from syft import NodeType\n",
" from syft.service.action.action_object import TwinMode\n",
"\n",
" action_service = context.node.get_service(\"actionservice\")\n",
" code_inputs = {}\n",
"\n",
" # When we are retrieving the code from the database, we need to use the node's\n",
" # verify key as the credentials. This is because when we approve the code, we\n",
" # we allow the private data to be used only for this specific code.\n",
" # but we are not modifying the permissions of the private data\n",
"\n",
" root_context = AuthedServiceContext(\n",
" node=context.node, credentials=context.node.verify_key\n",
" )\n",
" if context.node.node_type == NodeType.DOMAIN:\n",
" for var_name, arg_id in allowed_inputs.items():\n",
" kwarg_value = action_service._get(\n",
" context=root_context,\n",
" uid=arg_id,\n",
" twin_mode=TwinMode.NONE,\n",
" has_permission=True,\n",
" )\n",
" if kwarg_value.is_err():\n",
" return Err(kwarg_value.err())\n",
" code_inputs[var_name] = kwarg_value.ok()\n",
"\n",
" elif context.node.node_type == NodeType.ENCLAVE:\n",
" dict_object = action_service.get(context=root_context, uid=code_item_id)\n",
" if dict_object.is_err():\n",
" return Err(dict_object.err())\n",
" for value in dict_object.ok().syft_action_data.values():\n",
" code_inputs.update(value)\n",
"\n",
" else:\n",
" raise Exception(\n",
" f\"Invalid Node Type for Code Submission:{context.node.node_type}\"\n",
" )\n",
" return Ok(code_inputs)\n",
"\n",
" def allowed_ids_only(\n",
" self,\n",
" allowed_inputs,\n",
" kwargs,\n",
" context,\n",
" ):\n",
" # syft absolute\n",
" from syft import NodeType\n",
" from syft import UID\n",
"\n",
" if context.node.node_type == NodeType.DOMAIN:\n",
" node_identity = NodeIdentity(\n",
" node_name=context.node.name,\n",
" node_id=context.node.id,\n",
" verify_key=context.node.signing_key.verify_key,\n",
" )\n",
" allowed_inputs = allowed_inputs.get(node_identity, {})\n",
" elif context.node.node_type == NodeType.ENCLAVE:\n",
" base_dict = {}\n",
" for key in allowed_inputs.values():\n",
" base_dict.update(key)\n",
" allowed_inputs = base_dict\n",
" else:\n",
" raise Exception(\n",
" f\"Invalid Node Type for Code Submission:{context.node.node_type}\"\n",
" )\n",
" filtered_kwargs = {}\n",
" for key in allowed_inputs.keys():\n",
" if key in kwargs:\n",
" value = kwargs[key]\n",
" uid = value\n",
" if not isinstance(uid, UID):\n",
" uid = getattr(value, \"id\", None)\n",
"\n",
" if uid != allowed_inputs[key]:\n",
" raise Exception(\n",
" f\"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}\"\n",
" )\n",
" filtered_kwargs[key] = value\n",
" return filtered_kwargs\n",
"\n",
" def _is_valid(\n",
" self,\n",
" context,\n",
" usr_input_kwargs,\n",
" code_item_id,\n",
" ):\n",
" filtered_input_kwargs = self.filter_kwargs(\n",
" kwargs=usr_input_kwargs,\n",
" context=context,\n",
" code_item_id=code_item_id,\n",
" )\n",
"\n",
" if filtered_input_kwargs.is_err():\n",
" return filtered_input_kwargs\n",
"\n",
" filtered_input_kwargs = filtered_input_kwargs.ok()\n",
"\n",
" expected_input_kwargs = set()\n",
" for _inp_kwargs in self.inputs.values():\n",
" for k in _inp_kwargs.keys():\n",
" if k not in usr_input_kwargs:\n",
" return Err(f\"Function missing required keyword argument: '{k}'\")\n",
" expected_input_kwargs.update(_inp_kwargs.keys())\n",
"\n",
" permitted_input_kwargs = list(filtered_input_kwargs.keys())\n",
" not_approved_kwargs = set(expected_input_kwargs) - set(permitted_input_kwargs)\n",
" if len(not_approved_kwargs) > 0:\n",
" return Err(\n",
" f\"Input arguments: {not_approved_kwargs} to the function are not approved yet.\"\n",
" )\n",
" return Ok(True)\n",
"\n",
"\n",
"def allowed_ids_only(\n",
" self,\n",
" allowed_inputs,\n",
" kwargs,\n",
" context,\n",
"):\n",
" # syft absolute\n",
" from syft import NodeType\n",
" from syft import UID\n",
" from syft.client.api import NodeIdentity\n",
"\n",
" if context.node.node_type == NodeType.DOMAIN:\n",
" node_identity = NodeIdentity(\n",
" node_name=context.node.name,\n",
" node_id=context.node.id,\n",
" verify_key=context.node.signing_key.verify_key,\n",
" )\n",
" allowed_inputs = allowed_inputs.get(node_identity, {})\n",
" elif context.node.node_type == NodeType.ENCLAVE:\n",
" base_dict = {}\n",
" for key in allowed_inputs.values():\n",
" base_dict.update(key)\n",
" allowed_inputs = base_dict\n",
" else:\n",
" raise Exception(\n",
" f\"Invalid Node Type for Code Submission:{context.node.node_type}\"\n",
" )\n",
" filtered_kwargs = {}\n",
" for key in allowed_inputs.keys():\n",
" if key in kwargs:\n",
" value = kwargs[key]\n",
" uid = value\n",
" if not isinstance(uid, UID):\n",
" uid = getattr(value, \"id\", None)\n",
"\n",
" if uid != allowed_inputs[key]:\n",
" raise Exception(\n",
" f\"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}\"\n",
" )\n",
" filtered_kwargs[key] = value\n",
" return filtered_kwargs"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "17",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"@sy.syft_function(\n",
" input_policy=sy.ExactMatch(x=x_pointer),\n",
" input_policy=CustomExactMatch(x=x_pointer),\n",
" output_policy=RepeatedCallPolicy(n_calls=10, downloadable_output_args=[\"y\"]),\n",
")\n",
"def func(x):\n",
Expand All @@ -254,7 +448,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "17",
"id": "18",
"metadata": {
"tags": []
},
Expand All @@ -267,21 +461,13 @@
{
"cell_type": "code",
"execution_count": null,
"id": "18",
"id": "19",
"metadata": {},
"outputs": [],
"source": [
"request_id = request.id"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -433,14 +619,6 @@
"if node.node_type.value == \"python\":\n",
" node.land()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "32",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -459,7 +637,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
"version": "3.11.0rc1"
},
"toc": {
"base_numbering": 1,
Expand Down
28 changes: 25 additions & 3 deletions packages/syft/src/syft/client/syncing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@


def compare_states(
from_state: SyncState, to_state: SyncState, include_ignored: bool = False
from_state: SyncState,
to_state: SyncState,
include_ignored: bool = False,
include_same: bool = False,
filter_by_email: str | None = None,
filter_by_type: str | type | None = None,
) -> NodeDiff:
# NodeDiff
if (
Expand All @@ -42,11 +47,28 @@ def compare_states(
high_state=high_state,
direction=direction,
include_ignored=include_ignored,
include_same=include_same,
filter_by_email=filter_by_email,
filter_by_type=filter_by_type,
)


def compare_clients(low_client: SyftClient, high_client: SyftClient) -> NodeDiff:
return compare_states(low_client.get_sync_state(), high_client.get_sync_state())
def compare_clients(
from_client: SyftClient,
to_client: SyftClient,
include_ignored: bool = False,
include_same: bool = False,
filter_by_email: str | None = None,
filter_by_type: type | None = None,
) -> NodeDiff:
return compare_states(
from_client.get_sync_state(),
to_client.get_sync_state(),
include_ignored=include_ignored,
include_same=include_same,
filter_by_email=filter_by_email,
filter_by_type=filter_by_type,
)


def get_user_input_for_resolve() -> SyncDecision:
Expand Down
6 changes: 5 additions & 1 deletion packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
from ..policy.policy import filter_only_uids
from ..policy.policy import init_policy
from ..policy.policy import load_policy_code
from ..policy.policy import partition_by_node
from ..policy.policy_service import PolicyService
from ..response import SyftError
from ..response import SyftInfo
Expand Down Expand Up @@ -973,10 +974,13 @@ def syft_function(
if input_policy is None:
input_policy = EmpyInputPolicy()

init_input_kwargs = None
if isinstance(input_policy, CustomInputPolicy):
input_policy_type = SubmitUserPolicy.from_obj(input_policy)
init_input_kwargs = partition_by_node(input_policy.init_kwargs)
else:
input_policy_type = type(input_policy)
init_input_kwargs = getattr(input_policy, "init_kwargs", {})

if output_policy is None:
output_policy = SingleExecutionExactOutput()
Expand All @@ -992,7 +996,7 @@ def decorator(f: Any) -> SubmitUserCode:
func_name=f.__name__,
signature=inspect.signature(f),
input_policy_type=input_policy_type,
input_policy_init_kwargs=getattr(input_policy, "init_kwargs", {}),
input_policy_init_kwargs=init_input_kwargs,
output_policy_type=output_policy_type,
output_policy_init_kwargs=getattr(output_policy, "init_kwargs", {}),
local_function=f,
Expand Down
20 changes: 20 additions & 0 deletions packages/syft/src/syft/service/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,26 @@ def load_policy_code(user_policy: UserPolicy) -> Any:
def init_policy(user_policy: UserPolicy, init_args: dict[str, Any]) -> Any:
policy_class = load_policy_code(user_policy)
policy_object = policy_class()

# Unwrapp {NodeIdentity : {x: y}} -> {x: y}
# Tech debt : For input policies, we required to have NodeIdentity args beforehand,
# therefore at this stage we had to return back to the normal args.
# Maybe there's better way to do it.
if len(init_args) and isinstance(list(init_args.keys())[0], NodeIdentity):
unwrapped_init_kwargs = init_args
if len(init_args) > 1:
raise Exception("You shoudn't have more than one Node Identity.")
# Otherwise, unwrapp it
init_args = init_args[list(init_args.keys())[0]]

init_args = {k: v for k, v in init_args.items() if k != "id"}

# For input policies, this initializer wouldn't work properly:
# 1 - Passing {NodeIdentity: {kwargs:UIDs}} as keyword args doesn't work since keys must be strings
# 2 - Passing {kwargs: UIDs} in this initializer would not trigger the partition nodes from the
# InputPolicy initializer.
# The cleanest way to solve it is by checking if it's an Input Policy, and then, setting it manually.
policy_object.__user_init__(**init_args)
if isinstance(policy_object, InputPolicy):
policy_object.init_kwargs = unwrapped_init_kwargs
return policy_object
Loading

0 comments on commit f41c592

Please sign in to comment.