Skip to content

Commit

Permalink
Merge branch 'dev' into rasswanth/remove-old-enclave-code
Browse files Browse the repository at this point in the history
  • Loading branch information
rasswanth-s authored Jun 26, 2024
2 parents 1c728c6 + d678173 commit 5159e9e
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 49 deletions.
23 changes: 23 additions & 0 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,14 @@
from ...abstract_node import NodeType
from ...client.api import APIRegistry
from ...client.api import NodeIdentity
from ...client.api import generate_remote_function
from ...client.enclave_client import EnclaveMetadata
from ...node.credentials import SyftVerifyKey
from ...serde.deserialize import _deserialize
from ...serde.serializable import serializable
from ...serde.serialize import _serialize
from ...serde.signature import signature_remove_context
from ...serde.signature import signature_remove_self
from ...store.document_store import PartitionKey
from ...store.linked_obj import LinkedObject
from ...types.datetime import DateTime
Expand Down Expand Up @@ -928,6 +931,26 @@ def show_code_cell(self) -> None:
ip = get_ipython()
ip.set_next_input(warning_message + self.raw_code)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
api = self._get_api()
if isinstance(api, SyftError):
return api

signature = self.signature
signature = signature_remove_self(signature)
signature = signature_remove_context(signature)
remote_user_function = generate_remote_function(
api=api,
node_uid=self.node_uid,
signature=self.signature,
path="code.call",
make_call=api.make_call,
pre_kwargs={"uid": self.id},
warning=None,
communication_protocol=api.communication_protocol,
)
return remote_user_function(*args, **kwargs)


class UserCodeUpdate(PartialSyftObject):
__canonical_name__ = "UserCodeUpdate"
Expand Down
57 changes: 27 additions & 30 deletions packages/syft/src/syft/service/code/user_code_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,19 @@ def _submit(
) -> Result[UserCode, str]:
if not isinstance(code, UserCode):
code = code.to(UserCode, context=context) # type: ignore[unreachable]
result = self._post_user_code_transform_ops(context, code)
if isinstance(result, SyftError):
# if the validation fails, we should remove the user code status
# and code version to prevent dangling status
root_context = AuthedServiceContext(
credentials=context.node.verify_key, node=context.node
)

if code.status_link is not None:
_ = context.node.get_service("usercodestatusservice").remove(
root_context, code.status_link.object_uid
)
return result

result = self.stash.set(context.credentials, code)
return result
Expand Down Expand Up @@ -130,30 +143,7 @@ def get_by_service_name(
return SyftError(message=str(result.err()))
return result.ok()

def _request_code_execution(
self,
context: AuthedServiceContext,
code: SubmitUserCode,
reason: str | None = "",
) -> Request | SyftError:
user_code: UserCode = code.to(UserCode, context=context)
result = self._validate_request_code_execution(context, user_code)
if isinstance(result, SyftError):
# if the validation fails, we should remove the user code status
# and code version to prevent dangling status
root_context = AuthedServiceContext(
credentials=context.node.verify_key, node=context.node
)

if user_code.status_link is not None:
_ = context.node.get_service("usercodestatusservice").remove(
root_context, user_code.status_link.object_uid
)
return result
result = self._request_code_execution_inner(context, user_code, reason)
return result

def _validate_request_code_execution(
def _post_user_code_transform_ops(
self,
context: AuthedServiceContext,
user_code: UserCode,
Expand Down Expand Up @@ -194,10 +184,6 @@ def _validate_request_code_execution(
if isinstance(pool_result, SyftError):
return pool_result

result = self.stash.set(context.credentials, user_code)
if result.is_err():
return SyftError(message=str(result.err()))

# Create a code history
code_history_service = context.node.get_service("codehistoryservice")
result = code_history_service.submit_version(context=context, code=user_code)
Expand All @@ -206,7 +192,7 @@ def _validate_request_code_execution(

return SyftSuccess(message="")

def _request_code_execution_inner(
def _request_code_execution(
self,
context: AuthedServiceContext,
user_code: UserCode,
Expand Down Expand Up @@ -257,7 +243,18 @@ def request_code_execution(
reason: str | None = "",
) -> Request | SyftError:
"""Request Code execution on user code"""
return self._request_code_execution(context=context, code=code, reason=reason)

# TODO: check for duplicate submissions
user_code_or_err = self._submit(context, code)
if user_code_or_err.is_err():
return SyftError(message=user_code_or_err.err())

result = self._request_code_execution(
context,
user_code_or_err.ok(),
reason,
)
return result

@service_method(path="code.get_all", name="get_all", roles=GUEST_ROLE_LEVEL)
def get_all(self, context: AuthedServiceContext) -> list[UserCode] | SyftError:
Expand Down
6 changes: 5 additions & 1 deletion packages/syft/src/syft/service/code_history/code_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,11 @@ def __getitem__(self, index: int | str) -> UserCode | SyftError:
return SyftError(
message=f"Can't access the api. You must login to {self.node_uid}"
)
if api.user_role.value >= ServiceRole.DATA_OWNER.value and index < 0:
if (
api.user.get_current_user().role.value >= ServiceRole.DATA_OWNER.value
and index < 0
):
# negative index would dynamically resolve to a different version
return SyftError(
message="For security concerns we do not allow negative indexing. \
Try using absolute values when indexing"
Expand Down
30 changes: 19 additions & 11 deletions packages/syft/src/syft/service/code_history/code_history_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from ...util.telemetry import instrument
from ..code.user_code import SubmitUserCode
from ..code.user_code import UserCode
from ..code.user_code_service import UserCodeService
from ..context import AuthedServiceContext
from ..response import SyftError
from ..response import SyftSuccess
from ..service import AbstractService
from ..service import service_method
from ..user.user_roles import DATA_OWNER_ROLE_LEVEL
from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL
from ..user.user_roles import ServiceRole
from .code_history import CodeHistoriesDict
from .code_history import CodeHistory
from .code_history import CodeHistoryView
Expand Down Expand Up @@ -49,11 +51,6 @@ def submit_version(
if result.is_err():
return SyftError(message=str(result.err()))
code = result.ok()
elif isinstance(code, UserCode): # type: ignore[unreachable]
result = user_code_service.get_by_uid(context=context, uid=code.id)
if isinstance(result, SyftError):
return result
code = result

result = self.stash.get_by_service_func_name_and_verify_key(
credentials=context.credentials,
Expand Down Expand Up @@ -120,14 +117,22 @@ def delete(
def fetch_histories_for_user(
self, context: AuthedServiceContext, user_verify_key: SyftVerifyKey
) -> CodeHistoriesDict | SyftError:
result = self.stash.get_by_verify_key(
credentials=context.credentials, user_verify_key=user_verify_key
)
if context.role in [ServiceRole.DATA_OWNER, ServiceRole.ADMIN]:
result = self.stash.get_by_verify_key(
credentials=context.node.verify_key, user_verify_key=user_verify_key
)
else:
result = self.stash.get_by_verify_key(
credentials=context.credentials, user_verify_key=user_verify_key
)

user_code_service = context.node.get_service("usercodeservice")
user_code_service: UserCodeService = context.node.get_service("usercodeservice") # type: ignore

def get_code(uid: UID) -> UserCode | SyftError:
return user_code_service.get_by_uid(context=context, uid=uid)
return user_code_service.stash.get_by_uid(
credentials=context.node.verify_key,
uid=uid,
).ok()

if result.is_ok():
code_histories = result.ok()
Expand Down Expand Up @@ -186,7 +191,10 @@ def get_history_for_user(
def get_histories_group_by_user(
self, context: AuthedServiceContext
) -> UsersCodeHistoriesDict | SyftError:
result = self.stash.get_all(credentials=context.credentials)
if context.role in [ServiceRole.DATA_OWNER, ServiceRole.ADMIN]:
result = self.stash.get_all(context.credentials, has_permission=True)
else:
result = self.stash.get_all(context.credentials)
if result.is_err():
return SyftError(message=result.err())
code_histories: list[CodeHistory] = result.ok()
Expand Down
41 changes: 34 additions & 7 deletions packages/syft/tests/syft/users/user_code_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,23 +77,50 @@ def test_user_code(worker) -> None:
assert multi_call_res.get() == result.get()


def test_duplicated_user_code(worker, guest_client: User) -> None:
def test_duplicated_user_code(worker) -> None:
worker.root_client.register(
name="Jane Doe",
email="[email protected]",
password="abc123",
password_verify="abc123",
institution="Caltech",
website="https://www.caltech.edu/",
)
ds_client = worker.root_client.login(
email="[email protected]",
password="abc123",
)

# mock_syft_func()
result = guest_client.api.services.code.request_code_execution(mock_syft_func)
result = ds_client.api.services.code.request_code_execution(mock_syft_func)
assert isinstance(result, Request)
assert len(guest_client.code.get_all()) == 1
assert len(ds_client.code.get_all()) == 1

# request the exact same code should return an error
result = guest_client.api.services.code.request_code_execution(mock_syft_func)
result = ds_client.api.services.code.request_code_execution(mock_syft_func)
assert isinstance(result, SyftError)
assert len(guest_client.code.get_all()) == 1
assert len(ds_client.code.get_all()) == 1

# request the a different function name but same content will also succeed
# flaky if not blocking
mock_syft_func_2(syft_no_node=True)
result = guest_client.api.services.code.request_code_execution(mock_syft_func_2)
result = ds_client.api.services.code.request_code_execution(mock_syft_func_2)
assert isinstance(result, Request)
assert len(guest_client.code.get_all()) == 2
assert len(ds_client.code.get_all()) == 2

code_history = ds_client.code_history
assert code_history.code_versions, "No code version found."

code_histories = worker.root_client.code_histories
user_code_history = code_histories[ds_client.logged_in_user]
assert not isinstance(code_histories, SyftError)
assert not isinstance(user_code_history, SyftError)
assert user_code_history.code_versions, "No code version found."
assert user_code_history.mock_syft_func.user_code_history[0].status is not None
assert user_code_history.mock_syft_func[0]._repr_markdown_(), "repr markdown failed"

result = user_code_history.mock_syft_func_2[0]()
assert result.get() == 1


def random_hash() -> str:
Expand Down

0 comments on commit 5159e9e

Please sign in to comment.