Skip to content

Commit

Permalink
Merge pull request #9198 from OpenMined/bschell/test-for-custom-api-u…
Browse files Browse the repository at this point in the history
…ser-code-gen

Add tests for user code generator via custom API
  • Loading branch information
BrendanSchell authored Aug 28, 2024
2 parents 72f2580 + cad4a09 commit 47f6ba9
Show file tree
Hide file tree
Showing 7 changed files with 520 additions and 22 deletions.
6 changes: 2 additions & 4 deletions notebooks/api/0.8/05-custom-policy.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -258,18 +258,17 @@
" def __init__(self, *args: Any, **kwargs: Any) -> None:\n",
" pass\n",
"\n",
" def filter_kwargs(self, kwargs, context, code_item_id): # stdlib\n",
" def filter_kwargs(self, kwargs, context): # stdlib\n",
" allowed_inputs = self.allowed_ids_only(\n",
" allowed_inputs=self.inputs, kwargs=kwargs, context=context\n",
" )\n",
" results = self.retrieve_from_db(\n",
" code_item_id=code_item_id,\n",
" allowed_inputs=allowed_inputs,\n",
" context=context,\n",
" )\n",
" return results\n",
"\n",
" def retrieve_from_db(self, code_item_id, allowed_inputs, context):\n",
" def retrieve_from_db(self, allowed_inputs, context):\n",
" # syft absolute\n",
" from syft import ServerType\n",
" from syft.service.action.action_object import TwinMode\n",
Expand Down Expand Up @@ -346,7 +345,6 @@
" filtered_input_kwargs = self.filter_kwargs(\n",
" kwargs=usr_input_kwargs,\n",
" context=context,\n",
" code_item_id=code_item_id,\n",
" )\n",
"\n",
" expected_input_kwargs = set()\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,23 +201,22 @@
" def __init__(self, *args: Any, **kwargs: Any) -> None:\n",
" pass\n",
"\n",
" def filter_kwargs(self, kwargs, context, code_item_id):\n",
" def filter_kwargs(self, kwargs, context):\n",
" # stdlib\n",
"\n",
" try:\n",
" allowed_inputs = self.allowed_ids_only(\n",
" allowed_inputs=self.inputs, kwargs=kwargs, context=context\n",
" )\n",
" results = self.retrieve_from_db(\n",
" code_item_id=code_item_id,\n",
" allowed_inputs=allowed_inputs,\n",
" context=context,\n",
" )\n",
" except Exception as e:\n",
" return Err(str(e))\n",
" return results\n",
"\n",
" def retrieve_from_db(self, code_item_id, allowed_inputs, context):\n",
" def retrieve_from_db(self, allowed_inputs, context):\n",
" # syft absolute\n",
" from syft import ServerType\n",
" from syft.service.action.action_object import TwinMode\n",
Expand Down Expand Up @@ -295,7 +294,6 @@
" filtered_input_kwargs = self.filter_kwargs(\n",
" kwargs=usr_input_kwargs,\n",
" context=context,\n",
" code_item_id=code_item_id,\n",
" )\n",
"\n",
" if filtered_input_kwargs.is_err():\n",
Expand Down
6 changes: 3 additions & 3 deletions packages/syft/src/syft/service/action/action_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,15 +382,15 @@ def _user_code_execute(
input_policy.is_valid(
context=context,
usr_input_kwargs=kwargs,
code_item_id=code_item.id,
)

# Filter input kwargs based on policy
filtered_kwargs = input_policy.filter_kwargs(
kwargs=kwargs, context=context, code_item_id=code_item.id
kwargs=kwargs,
context=context,
)
else:
filtered_kwargs = retrieve_from_db(code_item.id, kwargs, context).unwrap()
filtered_kwargs = retrieve_from_db(kwargs, context).unwrap()

if hasattr(input_policy, "transform_kwargs"):
filtered_kwargs = input_policy.transform_kwargs( # type: ignore
Expand Down
1 change: 0 additions & 1 deletion packages/syft/src/syft/service/code/user_code_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,6 @@ def _call(
inp_policy_validation = input_policy.is_valid(
context,
usr_input_kwargs=kwarg2id,
code_item_id=code.id,
)

if not inp_policy_validation:
Expand Down
11 changes: 1 addition & 10 deletions packages/syft/src/syft/service/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,15 +399,13 @@ def is_valid( # type: ignore
self,
context: AuthedServiceContext,
usr_input_kwargs: dict,
code_item_id: UID,
) -> bool:
raise NotImplementedError

def filter_kwargs(
self,
kwargs: dict[Any, Any],
context: AuthedServiceContext,
code_item_id: UID,
) -> dict[Any, Any]:
raise NotImplementedError

Expand Down Expand Up @@ -541,7 +539,6 @@ def filter_kwargs( # type: ignore[override]
self,
kwargs: dict[str, UID],
context: AuthedServiceContext,
code_item_id: UID,
) -> dict[Any, Any]:
try:
res = {}
Expand Down Expand Up @@ -571,12 +568,10 @@ def is_valid( # type: ignore[override]
self,
context: AuthedServiceContext,
usr_input_kwargs: dict,
code_item_id: UID,
) -> bool:
filtered_input_kwargs = self.filter_kwargs(
kwargs=usr_input_kwargs,
context=context,
code_item_id=code_item_id,
)
expected_input_kwargs = set()

Expand All @@ -601,7 +596,7 @@ def is_valid( # type: ignore[override]

@as_result(SyftException, NotFoundException, StashException)
def retrieve_from_db(
code_item_id: UID, allowed_inputs: dict[str, UID], context: AuthedServiceContext
allowed_inputs: dict[str, UID], context: AuthedServiceContext
) -> dict[str, Any]:
# relative
from ...service.action.action_object import TwinMode
Expand Down Expand Up @@ -685,14 +680,12 @@ def filter_kwargs( # type: ignore
self,
kwargs: dict[Any, Any],
context: AuthedServiceContext,
code_item_id: UID,
) -> dict[Any, Any]:
allowed_inputs = allowed_ids_only(
allowed_inputs=self.inputs, kwargs=kwargs, context=context
).unwrap()

return retrieve_from_db(
code_item_id=code_item_id,
allowed_inputs=allowed_inputs,
context=context,
).unwrap()
Expand All @@ -701,12 +694,10 @@ def is_valid( # type: ignore
self,
context: AuthedServiceContext,
usr_input_kwargs: dict,
code_item_id: UID,
) -> bool:
filtered_input_kwargs = self.filter_kwargs(
kwargs=usr_input_kwargs,
context=context,
code_item_id=code_item_id,
)

expected_input_kwargs = set()
Expand Down
146 changes: 146 additions & 0 deletions packages/syft/tests/syft/service/policy/policy_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# third party
import pytest

# syft absolute
from syft import Asset
from syft import Constant
from syft import Dataset
from syft import MixedInputPolicy
from syft import syft_function
from syft.client.api import AuthedServiceContext
from syft.service.user.user_roles import ServiceRole
from syft.types.errors import SyftException


@pytest.fixture
def submit_code_with_constants_only(ds_client, worker):
input_policy = MixedInputPolicy(
endpoint=Constant(val="TEST ENDPOINT"),
query=Constant(val="TEST QUERY"),
client=ds_client,
)

@syft_function(
input_policy=input_policy,
)
def test_func():
return 1

admin_client = worker.root_client

ds_client.code.submit(test_func)

user_code = admin_client.api.services.code[0]

yield user_code


@pytest.fixture
def submit_code_with_mixed_inputs(ds_client, worker):
admin_client = worker.root_client
ds = Dataset(name="test", asset_list=[Asset(name="test", data=[1, 2], mock=[2, 3])])

admin_client.upload_dataset(ds)

asset = ds_client.datasets[0].assets[0]

mix_input_policy = MixedInputPolicy(
data=asset,
endpoint=Constant(val="TEST ENDPOINT"),
query=Constant(val="TEST QUERY"),
client=ds_client,
)

@syft_function(
input_policy=mix_input_policy,
)
def test_func_data(data, test_basic_python_type):
return data

admin_client = worker.root_client

ds_client.code.submit(test_func_data)

user_code = admin_client.api.services.code[0]

yield user_code


class TestMixedInputPolicy:
def test_constants_not_required(self, submit_code_with_constants_only):
user_code = submit_code_with_constants_only

policy = user_code.input_policy

assert policy.is_valid(context=None, usr_input_kwargs={})

def test_providing_constants_valid(self, submit_code_with_constants_only):
user_code = submit_code_with_constants_only

policy = user_code.input_policy

assert policy.is_valid(
context=None,
usr_input_kwargs={"endpoint": "TEST ENDPOINT", "query": "TEST QUERY"},
)

def test_constant_vals_can_be_retrieved_by_admin(
self, submit_code_with_constants_only
):
user_code = submit_code_with_constants_only

policy = user_code.input_policy

mapped_inputs = {k: v.val for k, v in list(policy.inputs.values())[0].items()}

assert mapped_inputs == {"endpoint": "TEST ENDPOINT", "query": "TEST QUERY"}

def test_mixed_inputs_invalid_without_same_ds(self, submit_code_with_mixed_inputs):
user_code = submit_code_with_mixed_inputs

policy = user_code.input_policy

with pytest.raises(SyftException):
policy.is_valid(context=None, usr_input_kwargs={})

def test_mixed_inputs_valid_with_same_asset(
self, worker, ds_client, submit_code_with_mixed_inputs
):
user_code = submit_code_with_mixed_inputs

policy = user_code.input_policy

asset = ds_client.datasets[0].assets[0]
ds_context = AuthedServiceContext(
server=worker,
credentials=ds_client.verify_key,
role=ServiceRole.DATA_SCIENTIST,
)
assert policy.is_valid(
context=ds_context, usr_input_kwargs={"data": asset.action_id}
)

def test_mixed_inputs_invalid_with_different_asset_raises(
self, worker, ds_client, submit_code_with_mixed_inputs
):
admin_client = worker.root_client

ds = Dataset(
name="different ds",
asset_list=[Asset(name="different asset", data=[1, 2], mock=[2, 3])],
)
admin_client.upload_dataset(ds)
user_code = submit_code_with_mixed_inputs

policy = user_code.input_policy

asset = ds_client.datasets["different ds"].assets[0]
ds_context = AuthedServiceContext(
server=worker,
credentials=ds_client.verify_key,
role=ServiceRole.DATA_SCIENTIST,
)
with pytest.raises(SyftException):
policy.is_valid(
context=ds_context, usr_input_kwargs={"data": asset.action_id}
)
Loading

0 comments on commit 47f6ba9

Please sign in to comment.