Skip to content

Commit

Permalink
Merge pull request #8294 from khoaguin/fix/duplicate-code-request
Browse files Browse the repository at this point in the history
fix: duplicated code requests
  • Loading branch information
shubham3121 authored Dec 4, 2023
2 parents 9b41f23 + 8dc5eb7 commit 8cafa14
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 3 deletions.
19 changes: 19 additions & 0 deletions notebooks/api/0.8/01-submit-code.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,25 @@
"assert len(jane_client.code.get_all()) == 1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# create the same code request with the exact same function should return an error\n",
"new_project.create_code_request(sum_trade_value_mil, jane_client)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"assert len(jane_client.code.get_all()) == 1"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
16 changes: 14 additions & 2 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
from .unparse import unparse

UserVerifyKeyPartitionKey = PartitionKey(key="user_verify_key", type_=SyftVerifyKey)
CodeHashPartitionKey = PartitionKey(key="code_hash", type_=int)
CodeHashPartitionKey = PartitionKey(key="code_hash", type_=str)
ServiceFuncNamePartitionKey = PartitionKey(key="service_func_name", type_=str)
SubmitTimePartitionKey = PartitionKey(key="submit_time", type_=DateTime)

Expand Down Expand Up @@ -273,6 +273,13 @@ class UserCodeV1(SyftObject):
enclave_metadata: Optional[EnclaveMetadata] = None
submit_time: Optional[DateTime]

__attr_searchable__ = [
"user_verify_key",
"status",
"service_func_name",
"code_hash",
]


@serializable()
class UserCode(SyftObject):
Expand Down Expand Up @@ -306,7 +313,12 @@ class UserCode(SyftObject):
nested_requests: Dict[str, str] = {}
nested_codes: Optional[Dict[str, Tuple[LinkedObject, Dict]]] = {}

__attr_searchable__ = ["user_verify_key", "status", "service_func_name"]
__attr_searchable__ = [
"user_verify_key",
"status",
"service_func_name",
"code_hash",
]
__attr_unique__ = []
__repr_attrs__ = ["service_func_name", "input_owners", "code_status"]

Expand Down
15 changes: 15 additions & 0 deletions packages/syft/src/syft/service/code/user_code_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,21 @@ def _request_code_execution_inner(
x in user_code.input_owner_verify_keys for x in user_code.output_readers
):
raise ValueError("outputs can only be distributed to input owners")

# check if the code with the same name and content already exists in the stash

find_results = self.stash.get_by_code_hash(
context.credentials, code_hash=user_code.code_hash
)
if find_results.is_err():
return SyftError(message=str(find_results.err()))
find_results = find_results.ok()

if find_results is not None:
return SyftError(
message="The code to be submitted (name and content) already exists"
)

result = self.stash.set(context.credentials, user_code)
if result.is_err():
return SyftError(message=str(result.err()))
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/service/code/user_code_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def get_all_by_user_verify_key(
return self.query_one(credentials=credentials, qks=qks)

def get_by_code_hash(
self, credentials: SyftVerifyKey, code_hash: int
self, credentials: SyftVerifyKey, code_hash: str
) -> Result[Optional[UserCode], str]:
qks = QueryKeys(qks=[CodeHashPartitionKey.with_obj(code_hash)])
return self.query_one(credentials=credentials, qks=qks)
Expand Down
27 changes: 27 additions & 0 deletions packages/syft/tests/syft/users/user_code_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
# syft absolute
import syft as sy
from syft.service.action.action_object import ActionObject
from syft.service.request.request import Request
from syft.service.request.request import UserCodeStatusChange
from syft.service.response import SyftError
from syft.service.user.user import User


Expand All @@ -20,6 +22,13 @@ def test_func():
return 1


@sy.syft_function(
input_policy=sy.ExactMatch(), output_policy=sy.SingleExecutionExactOutput()
)
def test_func_2():
return 1


def test_user_code(worker, guest_client: User) -> None:
test_func()
guest_client.api.services.code.request_code_execution(test_func)
Expand All @@ -38,6 +47,24 @@ def test_user_code(worker, guest_client: User) -> None:
assert isinstance(real_result, int)


def test_duplicated_user_code(worker, guest_client: User) -> None:
test_func()
result = guest_client.api.services.code.request_code_execution(test_func)
assert isinstance(result, Request)
assert len(guest_client.code.get_all()) == 1

# request the exact same code should return an error
result = guest_client.api.services.code.request_code_execution(test_func)
assert isinstance(result, SyftError)
assert len(guest_client.code.get_all()) == 1

# request the a different function name but same content will also succeed
test_func_2()
result = guest_client.api.services.code.request_code_execution(test_func_2)
assert isinstance(result, Request)
assert len(guest_client.code.get_all()) == 2


def random_hash() -> str:
return uuid.uuid4().hex[:16]

Expand Down

0 comments on commit 8cafa14

Please sign in to comment.