Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement UserCode.__call__ and fix code_history #8929

Merged
merged 14 commits into from
Jun 26, 2024
Merged
23 changes: 23 additions & 0 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,14 @@
from ...abstract_node import NodeType
from ...client.api import APIRegistry
from ...client.api import NodeIdentity
from ...client.api import generate_remote_function
from ...client.enclave_client import EnclaveMetadata
from ...node.credentials import SyftVerifyKey
from ...serde.deserialize import _deserialize
from ...serde.serializable import serializable
from ...serde.serialize import _serialize
from ...serde.signature import signature_remove_context
from ...serde.signature import signature_remove_self
from ...store.document_store import PartitionKey
from ...store.linked_obj import LinkedObject
from ...types.datetime import DateTime
Expand Down Expand Up @@ -909,6 +912,26 @@ def show_code_cell(self) -> None:
ip = get_ipython()
ip.set_next_input(warning_message + self.raw_code)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
api = self._get_api()
if isinstance(api, SyftError):
return api

signature = self.signature
signature = signature_remove_self(signature)
signature = signature_remove_context(signature)
remote_user_function = generate_remote_function(
api=api,
node_uid=self.node_uid,
signature=self.signature,
path="code.call",
make_call=api.make_call,
pre_kwargs={"uid": self.id},
warning=None,
communication_protocol=api.communication_protocol,
)
return remote_user_function(*args, **kwargs)


class UserCodeUpdate(PartialSyftObject):
__canonical_name__ = "UserCodeUpdate"
Expand Down
57 changes: 27 additions & 30 deletions packages/syft/src/syft/service/code/user_code_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,19 @@ def _submit(
) -> Result[UserCode, str]:
if not isinstance(code, UserCode):
code = code.to(UserCode, context=context) # type: ignore[unreachable]
result = self._post_user_code_transform_ops(context, code)
if isinstance(result, SyftError):
# if the validation fails, we should remove the user code status
abyesilyurt marked this conversation as resolved.
Show resolved Hide resolved
# and code version to prevent dangling status
root_context = AuthedServiceContext(
credentials=context.node.verify_key, node=context.node
)

if code.status_link is not None:
_ = context.node.get_service("usercodestatusservice").remove(
root_context, code.status_link.object_uid
)
return result

result = self.stash.set(context.credentials, code)
return result
Expand Down Expand Up @@ -133,30 +146,7 @@ def get_by_service_name(
return SyftError(message=str(result.err()))
return result.ok()

def _request_code_execution(
self,
context: AuthedServiceContext,
code: SubmitUserCode,
reason: str | None = "",
) -> Request | SyftError:
user_code: UserCode = code.to(UserCode, context=context)
result = self._validate_request_code_execution(context, user_code)
if isinstance(result, SyftError):
# if the validation fails, we should remove the user code status
# and code version to prevent dangling status
root_context = AuthedServiceContext(
credentials=context.node.verify_key, node=context.node
)

if user_code.status_link is not None:
_ = context.node.get_service("usercodestatusservice").remove(
root_context, user_code.status_link.object_uid
)
return result
result = self._request_code_execution_inner(context, user_code, reason)
return result

def _validate_request_code_execution(
def _post_user_code_transform_ops(
self,
context: AuthedServiceContext,
user_code: UserCode,
Expand Down Expand Up @@ -197,10 +187,6 @@ def _validate_request_code_execution(
if isinstance(pool_result, SyftError):
return pool_result

result = self.stash.set(context.credentials, user_code)
if result.is_err():
return SyftError(message=str(result.err()))

# Create a code history
code_history_service = context.node.get_service("codehistoryservice")
result = code_history_service.submit_version(context=context, code=user_code)
Expand All @@ -209,7 +195,7 @@ def _validate_request_code_execution(

return SyftSuccess(message="")

def _request_code_execution_inner(
def _request_code_execution(
self,
context: AuthedServiceContext,
user_code: UserCode,
Expand Down Expand Up @@ -260,7 +246,18 @@ def request_code_execution(
reason: str | None = "",
) -> Request | SyftError:
"""Request Code execution on user code"""
return self._request_code_execution(context=context, code=code, reason=reason)

# TODO: check for duplicate submissions
user_code_or_err = self._submit(context, code)
if user_code_or_err.is_err():
return SyftError(message=user_code_or_err.err())

result = self._request_code_execution(
context,
user_code_or_err.ok(),
reason,
)
return result

@service_method(path="code.get_all", name="get_all", roles=GUEST_ROLE_LEVEL)
def get_all(self, context: AuthedServiceContext) -> list[UserCode] | SyftError:
Expand Down
6 changes: 5 additions & 1 deletion packages/syft/src/syft/service/code_history/code_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@ def __getitem__(self, index: int | str) -> UserCode | SyftError:
return SyftError(
message=f"Can't access the api. You must login to {self.node_uid}"
)
if api.user_role.value >= ServiceRole.DATA_OWNER.value and index < 0:
if (
api.user.get_current_user().role.value >= ServiceRole.DATA_OWNER.value
and index < 0
):
# negative index would dynamically resolve to a different version
return SyftError(
message="For security concerns we do not allow negative indexing. \
Try using absolute values when indexing"
Expand Down
30 changes: 19 additions & 11 deletions packages/syft/src/syft/service/code_history/code_history_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from ...util.telemetry import instrument
from ..code.user_code import SubmitUserCode
from ..code.user_code import UserCode
from ..code.user_code_service import UserCodeService
from ..context import AuthedServiceContext
from ..response import SyftError
from ..response import SyftSuccess
from ..service import AbstractService
from ..service import service_method
from ..user.user_roles import DATA_OWNER_ROLE_LEVEL
from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL
from ..user.user_roles import ServiceRole
from .code_history import CodeHistoriesDict
from .code_history import CodeHistory
from .code_history import CodeHistoryView
Expand Down Expand Up @@ -49,11 +51,6 @@ def submit_version(
if result.is_err():
return SyftError(message=str(result.err()))
code = result.ok()
elif isinstance(code, UserCode): # type: ignore[unreachable]
result = user_code_service.get_by_uid(context=context, uid=code.id)
if isinstance(result, SyftError):
return result
code = result

result = self.stash.get_by_service_func_name_and_verify_key(
credentials=context.credentials,
Expand Down Expand Up @@ -120,14 +117,22 @@ def delete(
def fetch_histories_for_user(
self, context: AuthedServiceContext, user_verify_key: SyftVerifyKey
) -> CodeHistoriesDict | SyftError:
result = self.stash.get_by_verify_key(
credentials=context.credentials, user_verify_key=user_verify_key
)
if context.role in [ServiceRole.DATA_OWNER, ServiceRole.ADMIN]:
result = self.stash.get_by_verify_key(
credentials=context.node.verify_key, user_verify_key=user_verify_key
)
else:
result = self.stash.get_by_verify_key(
credentials=context.credentials, user_verify_key=user_verify_key
)

user_code_service = context.node.get_service("usercodeservice")
user_code_service: UserCodeService = context.node.get_service("usercodeservice") # type: ignore

def get_code(uid: UID) -> UserCode | SyftError:
return user_code_service.get_by_uid(context=context, uid=uid)
return user_code_service.stash.get_by_uid(
credentials=context.node.verify_key,
uid=uid,
).ok()

if result.is_ok():
code_histories = result.ok()
Expand Down Expand Up @@ -186,7 +191,10 @@ def get_history_for_user(
def get_histories_group_by_user(
self, context: AuthedServiceContext
) -> UsersCodeHistoriesDict | SyftError:
result = self.stash.get_all(credentials=context.credentials)
if context.role in [ServiceRole.DATA_OWNER, ServiceRole.ADMIN]:
result = self.stash.get_all(context.credentials, has_permission=True)
else:
result = self.stash.get_all(context.credentials)
if result.is_err():
return SyftError(message=result.err())
code_histories: list[CodeHistory] = result.ok()
Expand Down
41 changes: 34 additions & 7 deletions packages/syft/tests/syft/users/user_code_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,23 +77,50 @@ def test_user_code(worker) -> None:
assert multi_call_res.get() == result.get()


def test_duplicated_user_code(worker, guest_client: User) -> None:
def test_duplicated_user_code(worker) -> None:
worker.root_client.register(
name="Jane Doe",
email="[email protected]",
password="abc123",
password_verify="abc123",
institution="Caltech",
website="https://www.caltech.edu/",
)
ds_client = worker.root_client.login(
email="[email protected]",
password="abc123",
)

# mock_syft_func()
result = guest_client.api.services.code.request_code_execution(mock_syft_func)
result = ds_client.api.services.code.request_code_execution(mock_syft_func)
assert isinstance(result, Request)
assert len(guest_client.code.get_all()) == 1
assert len(ds_client.code.get_all()) == 1

# request the exact same code should return an error
result = guest_client.api.services.code.request_code_execution(mock_syft_func)
result = ds_client.api.services.code.request_code_execution(mock_syft_func)
assert isinstance(result, SyftError)
assert len(guest_client.code.get_all()) == 1
assert len(ds_client.code.get_all()) == 1

# request the a different function name but same content will also succeed
# flaky if not blocking
mock_syft_func_2(syft_no_node=True)
result = guest_client.api.services.code.request_code_execution(mock_syft_func_2)
result = ds_client.api.services.code.request_code_execution(mock_syft_func_2)
assert isinstance(result, Request)
assert len(guest_client.code.get_all()) == 2
assert len(ds_client.code.get_all()) == 2

code_history = ds_client.code_history
assert code_history.code_versions, "No code version found."

code_histories = worker.root_client.code_histories
user_code_history = code_histories[ds_client.logged_in_user]
assert not isinstance(code_histories, SyftError)
assert not isinstance(user_code_history, SyftError)
assert user_code_history.code_versions, "No code version found."
assert user_code_history.mock_syft_func.user_code_history[0].status is not None
assert user_code_history.mock_syft_func[0]._repr_markdown_(), "repr markdown failed"

result = user_code_history.mock_syft_func_2[0]()
assert result.get() == 1


def random_hash() -> str:
Expand Down
Loading