Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement UserCode.__call__ and fix code_history #8929

Merged
merged 14 commits into from
Jun 26, 2024
Merged
23 changes: 23 additions & 0 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,14 @@
from ...abstract_node import NodeType
from ...client.api import APIRegistry
from ...client.api import NodeIdentity
from ...client.api import generate_remote_function
from ...client.enclave_client import EnclaveMetadata
from ...node.credentials import SyftVerifyKey
from ...serde.deserialize import _deserialize
from ...serde.serializable import serializable
from ...serde.serialize import _serialize
from ...serde.signature import signature_remove_context
from ...serde.signature import signature_remove_self
from ...store.document_store import PartitionKey
from ...store.linked_obj import LinkedObject
from ...types.datetime import DateTime
Expand Down Expand Up @@ -891,6 +894,26 @@ def show_code_cell(self) -> None:
ip = get_ipython()
ip.set_next_input(warning_message + self.raw_code)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
api = self._get_api()
if isinstance(api, SyftError):
return api

signature = self.signature
signature = signature_remove_self(signature)
signature = signature_remove_context(signature)
remote_user_function = generate_remote_function(
api=api,
node_uid=self.node_uid,
signature=self.signature,
path="code.call",
make_call=api.make_call,
pre_kwargs={"uid": self.id},
warning=None,
communication_protocol=api.communication_protocol,
)
return remote_user_function()


class UserCodeUpdate(PartialSyftObject):
__canonical_name__ = "UserCodeUpdate"
Expand Down
57 changes: 27 additions & 30 deletions packages/syft/src/syft/service/code/user_code_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,19 @@ def _submit(
) -> Result[UserCode, str]:
if not isinstance(code, UserCode):
code = code.to(UserCode, context=context) # type: ignore[unreachable]
result = self._post_user_code_transform_ops(context, code)
if isinstance(result, SyftError):
# if the validation fails, we should remove the user code status
abyesilyurt marked this conversation as resolved.
Show resolved Hide resolved
# and code version to prevent dangling status
root_context = AuthedServiceContext(
credentials=context.node.verify_key, node=context.node
)

if code.status_link is not None:
_ = context.node.get_service("usercodestatusservice").remove(
root_context, code.status_link.object_uid
)
return result

result = self.stash.set(context.credentials, code)
return result
Expand Down Expand Up @@ -133,30 +146,7 @@ def get_by_service_name(
return SyftError(message=str(result.err()))
return result.ok()

def _request_code_execution(
self,
context: AuthedServiceContext,
code: SubmitUserCode,
reason: str | None = "",
) -> Request | SyftError:
user_code: UserCode = code.to(UserCode, context=context)
result = self._validate_request_code_execution(context, user_code)
if isinstance(result, SyftError):
# if the validation fails, we should remove the user code status
# and code version to prevent dangling status
root_context = AuthedServiceContext(
credentials=context.node.verify_key, node=context.node
)

if user_code.status_link is not None:
_ = context.node.get_service("usercodestatusservice").remove(
root_context, user_code.status_link.object_uid
)
return result
result = self._request_code_execution_inner(context, user_code, reason)
return result

def _validate_request_code_execution(
def _post_user_code_transform_ops(
self,
context: AuthedServiceContext,
user_code: UserCode,
Expand Down Expand Up @@ -197,10 +187,6 @@ def _validate_request_code_execution(
if isinstance(pool_result, SyftError):
return pool_result

result = self.stash.set(context.credentials, user_code)
if result.is_err():
return SyftError(message=str(result.err()))

# Create a code history
code_history_service = context.node.get_service("codehistoryservice")
result = code_history_service.submit_version(context=context, code=user_code)
Expand All @@ -209,7 +195,7 @@ def _validate_request_code_execution(

return SyftSuccess(message="")

def _request_code_execution_inner(
def _request_code_execution(
self,
context: AuthedServiceContext,
user_code: UserCode,
Expand Down Expand Up @@ -260,7 +246,18 @@ def request_code_execution(
reason: str | None = "",
) -> SyftSuccess | SyftError:
"""Request Code execution on user code"""
return self._request_code_execution(context=context, code=code, reason=reason)

# TODO: check for duplicate submissions
user_code_or_err = self._submit(context, code)
if user_code_or_err.is_err():
return SyftError(user_code_or_err.err())

result = self._request_code_execution(
context,
user_code_or_err.ok(),
reason,
)
return result

@service_method(path="code.get_all", name="get_all", roles=GUEST_ROLE_LEVEL)
def get_all(self, context: AuthedServiceContext) -> list[UserCode] | SyftError:
Expand Down
6 changes: 5 additions & 1 deletion packages/syft/src/syft/service/code_history/code_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ def __getitem__(self, index: int | str) -> UserCode | SyftError:
return SyftError(
message=f"Can't access the api. You must login to {self.node_uid}"
)
if api.user_role.value >= ServiceRole.DATA_OWNER.value and index < 0:
if (
api.user.get_current_user().role.value >= ServiceRole.DATA_OWNER.value
and index < 0
):
# negative index would dynamically resolve to a different version
return SyftError(
message="For security concerns we do not allow negative indexing. \
Try using absolute values when indexing"
Expand Down
38 changes: 31 additions & 7 deletions packages/syft/tests/syft/users/user_code_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,23 +77,47 @@ def test_user_code(worker) -> None:
assert multi_call_res.get() == result.get()


def test_duplicated_user_code(worker, guest_client: User) -> None:
def test_duplicated_user_code(worker) -> None:
worker.root_client.register(
name="Jane Doe",
email="[email protected]",
password="abc123",
password_verify="abc123",
institution="Caltech",
website="https://www.caltech.edu/",
)
ds_client = worker.root_client.login(
email="[email protected]",
password="abc123",
)

# mock_syft_func()
result = guest_client.api.services.code.request_code_execution(mock_syft_func)
result = ds_client.api.services.code.request_code_execution(mock_syft_func)
assert isinstance(result, Request)
assert len(guest_client.code.get_all()) == 1
assert len(ds_client.code.get_all()) == 1

# request the exact same code should return an error
result = guest_client.api.services.code.request_code_execution(mock_syft_func)
result = ds_client.api.services.code.request_code_execution(mock_syft_func)
assert isinstance(result, SyftError)
assert len(guest_client.code.get_all()) == 1
assert len(ds_client.code.get_all()) == 1

# request the a different function name but same content will also succeed
# flaky if not blocking
mock_syft_func_2(syft_no_node=True)
result = guest_client.api.services.code.request_code_execution(mock_syft_func_2)
result = ds_client.api.services.code.request_code_execution(mock_syft_func_2)
assert isinstance(result, Request)
assert len(guest_client.code.get_all()) == 2
assert len(ds_client.code.get_all()) == 2

code_history = ds_client.code_history
assert code_history.code_versions, "No code version found."

code_histories = worker.root_client.code_histories
user_code_history = code_histories[ds_client.logged_in_user]
assert not isinstance(user_code_history, SyftError)
assert user_code_history.code_versions, "No code version found."

result = user_code_history.mock_syft_func_2[0]()
assert result.get() == 1


def random_hash() -> str:
Expand Down
Loading