Skip to content

Commit

Permalink
Merge pull request #8132 from khoaguin/bugfix/request-denial-reason-n…
Browse files Browse the repository at this point in the history
…ot-shown

fix: show reason for request denial in the error
  • Loading branch information
shubham3121 authored Oct 9, 2023
2 parents 6f97c46 + 0676dfb commit 47875b9
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@
"metadata": {},
"outputs": [],
"source": [
"for st in status.status_dict.values():\n",
"for st, _ in status.status_dict.values():\n",
" assert st == sy.service.request.request.UserCodeStatus.APPROVED"
]
},
Expand Down
4 changes: 2 additions & 2 deletions packages/syft/src/syft/external/oblv/oblv_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,9 @@ def send_user_code_inputs_to_enclave(
message=f"Unable to find {user_code_id} in {type(user_code_service)}"
)
user_code = user_code.ok()

reason: str = context.extra_kwargs.get("reason", "")
res = user_code.status.mutate(
value=UserCodeStatus.APPROVED,
value=(UserCodeStatus.APPROVED, reason),
node_name=node_name,
verify_key=context.credentials,
)
Expand Down
49 changes: 32 additions & 17 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Type
from typing import Union

Expand Down Expand Up @@ -109,7 +110,7 @@ def __hash__(self) -> int:
# as status is in attr_searchable
@serializable(attrs=["status_dict"])
class UserCodeStatusCollection(SyftHashableObject):
status_dict: Dict[NodeIdentity, UserCodeStatus] = {}
status_dict: Dict[NodeIdentity, Tuple[UserCodeStatus, str]] = {}

def __init__(self, status_dict: Dict):
self.status_dict = status_dict
Expand All @@ -126,49 +127,59 @@ def _repr_html_(self):
<h3 style="line-height: 25%; margin-top: 25px;">User Code Status</h3>
<p style="margin-left: 3px;">
"""
for node_identity, status in self.status_dict.items():
for node_identity, (status, reason) in self.status_dict.items():
node_name_str = f"{node_identity.node_name}"
uid_str = f"{node_identity.node_id}"
status_str = f"{status.value}"
string += f"""
&#x2022; <strong>UID: </strong>{uid_str}&nbsp;
<strong>Node name: </strong>{node_name_str}&nbsp;
<strong>Status: </strong>{status_str}
<strong>Status: </strong>{status_str};
<strong>Reason: </strong>{reason}
<br>
"""
string += "</p></div>"
return string

def __repr_syft_nested__(self):
string = ""
for node_identity, status in self.status_dict.items():
string += f"{node_identity.node_name}: {status}<br>"
for node_identity, (status, reason) in self.status_dict.items():
string += f"{node_identity.node_name}: {status}, {reason}<br>"
return string

def get_status_message(self):
if self.approved:
return SyftSuccess(message=f"{type(self)} approved")
denial_string = ""
string = ""
for node_identity, status in self.status_dict.items():
string += f"Code status on node '{node_identity.node_name}' is '{status}'. "
for node_identity, (status, reason) in self.status_dict.items():
denial_string += f"Code status on node '{node_identity.node_name}' is '{status}'. Reason: {reason}"
if not reason.endswith("."):
denial_string += "."
string += f"Code status on node '{node_identity.node_name}' is '{status}'."
if self.denied:
return SyftError(message=f"{type(self)} Your code cannot be run: {string}")
return SyftError(
message=f"{type(self)} Your code cannot be run: {denial_string}"
)
else:
return SyftNotReady(
message=f"{type(self)} Your code is waiting for approval. {string}"
)

@property
def approved(self) -> bool:
return all(x == UserCodeStatus.APPROVED for x in self.status_dict.values())
return all(x == UserCodeStatus.APPROVED for x, _ in self.status_dict.values())

@property
def denied(self) -> bool:
return UserCodeStatus.DENIED in self.status_dict.values()
for status, _ in self.status_dict.values():
if status == UserCodeStatus.DENIED:
return True
return False

def for_user_context(self, context: AuthedServiceContext) -> UserCodeStatus:
if context.node.node_type == NodeType.ENCLAVE:
keys = set(self.status_dict.values())
keys = {status for status, _ in self.status_dict.values()}
if len(keys) == 1 and UserCodeStatus.APPROVED in keys:
return UserCodeStatus.APPROVED
elif UserCodeStatus.PENDING in keys and UserCodeStatus.DENIED not in keys:
Expand All @@ -185,7 +196,7 @@ def for_user_context(self, context: AuthedServiceContext) -> UserCodeStatus:
verify_key=context.node.signing_key.verify_key,
)
if node_identity in self.status_dict:
return self.status_dict[node_identity]
return self.status_dict[node_identity][0]
else:
raise Exception(
f"Code Object does not contain {context.node.name} Domain's data"
Expand All @@ -196,7 +207,11 @@ def for_user_context(self, context: AuthedServiceContext) -> UserCodeStatus:
)

def mutate(
self, value: UserCodeStatus, node_name: str, node_id, verify_key: SyftVerifyKey
self,
value: Tuple[UserCodeStatus, str],
node_name: str,
node_id,
verify_key: SyftVerifyKey,
) -> Union[SyftError, Self]:
node_identity = NodeIdentity(
node_name=node_name, node_id=node_id, verify_key=verify_key
Expand Down Expand Up @@ -251,7 +266,7 @@ def __setattr__(self, key: str, value: Any) -> None:
return super().__setattr__(key, value)

def _coll_repr_(self) -> Dict[str, Any]:
status = list(self.status.status_dict.values())[0].value
status = [status for status, _ in self.status.status_dict.values()][0].value
if status == UserCodeStatus.PENDING.value:
badge_color = "badge-purple"
elif status == UserCodeStatus.APPROVED.value:
Expand Down Expand Up @@ -296,7 +311,7 @@ def output_readers(self) -> List[SyftVerifyKey]:
@property
def code_status(self) -> list:
status_list = []
for node_view, status in self.status.status_dict.items():
for node_view, (status, _) in self.status.status_dict.items():
status_list.append(
f"Node: {node_view.node_name}, Status: {status.value}",
)
Expand Down Expand Up @@ -795,7 +810,7 @@ def add_custom_status(context: TransformContext) -> TransformContext:
verify_key=context.node.signing_key.verify_key,
)
context.output["status"] = UserCodeStatusCollection(
status_dict={node_identity: UserCodeStatus.PENDING}
status_dict={node_identity: (UserCodeStatus.PENDING, "")}
)
# if node_identity in input_keys or len(input_keys) == 0:
# context.output["status"] = UserCodeStatusContext(
Expand All @@ -804,7 +819,7 @@ def add_custom_status(context: TransformContext) -> TransformContext:
# else:
# raise ValueError(f"Invalid input keys: {input_keys} for {node_identity}")
elif context.node.node_type == NodeType.ENCLAVE:
status_dict = {key: UserCodeStatus.PENDING for key in input_keys}
status_dict = {key: (UserCodeStatus.PENDING, "") for key in input_keys}
context.output["status"] = UserCodeStatusCollection(status_dict=status_dict)
else:
raise NotImplementedError(
Expand Down
5 changes: 4 additions & 1 deletion packages/syft/src/syft/service/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,12 @@ class ChangeContext(SyftBaseObject):
node: Optional[AbstractNode] = None
approving_user_credentials: Optional[SyftVerifyKey]
requesting_user_credentials: Optional[SyftVerifyKey]
extra_kwargs: Dict = {}

@staticmethod
def from_service(context: AuthedServiceContext) -> Self:
return ChangeContext(
node=context.node, approving_user_credentials=context.credentials
node=context.node,
approving_user_credentials=context.credentials,
extra_kwargs=context.extra_kwargs,
)
3 changes: 2 additions & 1 deletion packages/syft/src/syft/service/enclave/enclave_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ def send_user_code_inputs_to_enclave(
if isinstance(user_code, SyftError):
return user_code

reason: str = context.extra_kwargs.get("reason", "")
status_update = user_code.status.mutate(
value=UserCodeStatus.APPROVED,
value=(UserCodeStatus.APPROVED, reason),
node_name=node_name,
node_id=node_id,
verify_key=context.credentials,
Expand Down
5 changes: 3 additions & 2 deletions packages/syft/src/syft/service/request/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -807,16 +807,17 @@ def valid(self) -> Union[SyftSuccess, SyftError]:
return SyftSuccess(message=f"{type(self)} valid")

def mutate(self, obj: UserCode, context: ChangeContext, undo: bool) -> Any:
reason: str = context.extra_kwargs.get("reason", "")
if not undo:
res = obj.status.mutate(
value=self.value,
value=(self.value, reason),
node_name=context.node.name,
node_id=context.node.id,
verify_key=context.node.signing_key.verify_key,
)
else:
res = obj.status.mutate(
value=UserCodeStatus.DENIED,
value=(UserCodeStatus.DENIED, reason),
node_name=context.node.name,
node_id=context.node.id,
verify_key=context.node.signing_key.verify_key,
Expand Down
3 changes: 2 additions & 1 deletion packages/syft/src/syft/service/request/request_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def undo(
if request is None:
return SyftError(message=f"Request with uid: {uid} does not exists.")

context.extra_kwargs["reason"] = reason
result = request.undo(context=context)

if result.is_err():
Expand All @@ -254,8 +255,8 @@ def undo(
linked_obj=link,
)
send_notification = context.node.get_service_method(NotificationService.send)
send_notification(context=context, notification=notification)

result = send_notification(context=context, notification=notification)
return SyftSuccess(message=f"Request {uid} successfully denied !")

def save(
Expand Down

0 comments on commit 47875b9

Please sign in to comment.