Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Demo improvements #8137

Merged
merged 20 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
965 changes: 965 additions & 0 deletions notebooks/helm/new_policy.ipynb

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions packages/syft/src/syft/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,12 @@ def exchange_route(self, client: Self) -> Union[SyftSuccess, SyftError]:

return result

@property
def jobs(self) -> Optional[APIModule]:
if self.api.has_service("job"):
return self.api.services.job
return None

@property
def users(self) -> Optional[APIModule]:
if self.api.has_service("user"):
Expand Down
4 changes: 3 additions & 1 deletion packages/syft/src/syft/service/action/action_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,11 @@ def _user_code_execute(
code_item: UserCode,
kwargs: Dict[str, Any],
) -> Result[ActionObjectPointer, Err]:
filtered_kwargs = code_item.input_policy.filter_kwargs(
input_policy = code_item.input_policy
filtered_kwargs = input_policy.filter_kwargs(
kwargs=kwargs, context=context, code_item_id=code_item.id
)
code_item.input_policy = input_policy
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this needed?

Copy link
Member Author

@teo-milea teo-milea Oct 16, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you check the UserCode class, you can see that the input policy is actually a property that also has a setter method, so in order to save the state we need to do this. This is exactly the same mechanism as the one for the output policy


expected_input_kwargs = set()
for _inp_kwarg in code_item.input_policy.inputs.values():
Expand Down
9 changes: 9 additions & 0 deletions packages/syft/src/syft/service/blob_storage/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def mark_write_complete(
context: AuthedServiceContext,
uid: UID,
etags: List,
no_lines: Optional[int] = 0,
) -> Union[SyftError, SyftSuccess]:
result = self.stash.get_by_uid(
credentials=context.credentials,
Expand All @@ -162,6 +163,14 @@ def mark_write_complete(
if obj is None:
return SyftError(message=f"No blob storage entry exists for uid: {uid}")

obj.no_lines = no_lines
result = self.stash.update(
credentials=context.credentials,
obj=obj,
)
if result.is_err():
return SyftError(message=f"{result.err()}")

with context.node.blob_storage_client.connect() as conn:
result = conn.complete_multipart_upload(obj, etags)

Expand Down
64 changes: 47 additions & 17 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,15 +878,42 @@ def execute_byte_code(
original_print = __builtin__.print

class LocalDomainClient:
def __init__(self):
def __init__(self, context):
self.context = context
# self.print_func = copy.copy(print_func)
pass

def init_checkpoint(self, max_checkpoints):
if self.context.job is not None:
node = self.context.node
job_service = node.get_service("jobservice")
# user_service = node.get_service("userservice")
# admin_context = AuthedServiceContext(
# node=node,
# credentials=user_service.admin_verify_key(),
# role=ServiceRole.ADMIN,
# )
job = self.context.job
job.current_checkpoint = 0
job.max_checkpoints = max_checkpoints
job_service.update(self.context, job)
# return res

def checkpoint(self):
if self.context.job is not None:
node = self.context.node
job_service = node.get_service("jobservice")
job = self.context.job
job.current_checkpoint += 1
job_service.update(self.context, job)
# return res

def launch_job(self, func: UserCode, **kwargs):
# relative
from ... import UID

# get reference to node (TODO)
node = context.node
node = self.context.node
action_service = node.get_service("actionservice")
user_service = node.get_service("userservice")
user_code_service = node.get_service("usercodeservice")
Expand All @@ -898,7 +925,7 @@ def launch_job(self, func: UserCode, **kwargs):
kw2id = {}
for k, v in kwargs.items():
value = ActionObject.from_obj(v)
ptr = action_service.set(context, value)
ptr = action_service.set(self.context, value)
ptr = ptr.ok()
kw2id[k] = ptr.id

Expand All @@ -916,7 +943,7 @@ def launch_job(self, func: UserCode, **kwargs):

# TODO: throw exception for enclaves
request = user_code_service._request_code_execution_inner(
context, new_user_code
self.context, new_user_code
).ok()
admin_context = AuthedServiceContext(
node=node,
Expand All @@ -939,24 +966,24 @@ def launch_job(self, func: UserCode, **kwargs):

original_print(f"LAUNCHING JOB {func.service_func_name}")
job = node.add_api_call_to_queue(
api_call, parent_job_id=context.job_id
api_call, parent_job_id=self.context.job_id
)

# set api in global scope to enable using .get(), .wait())
user_signing_key = [
x.signing_key
for x in user_service.stash.partition.data.values()
if x.verify_key == context.credentials
if x.verify_key == self.context.credentials
][0]
user_api = node.get_api(context.credentials)
user_api = node.get_api(self.context.credentials)
user_api.signing_key = user_signing_key
# We hardcode a python connection here since we have access to the node
# TODO: this is not secure
user_api.connection = PythonConnection(node=node)

APIRegistry.set_api_for(
node_uid=node.id,
user_verify_key=context.credentials,
user_verify_key=self.context.credentials,
api=user_api,
)

Expand All @@ -966,6 +993,8 @@ def launch_job(self, func: UserCode, **kwargs):
raise ValueError(f"error while launching job:\n{e}")

if context.job is not None:
job_id = context.job_id
log_id = context.job.log_id

def print(*args, sep=" ", end="\n"):
def to_str(arg: Any) -> str:
Expand All @@ -982,11 +1011,9 @@ def to_str(arg: Any) -> str:
new_args = [to_str(arg) for arg in args]
new_str = sep.join(new_args) + end
log_service = context.node.get_service("LogService")
log_service.append(
context=context, uid=context.job.log_id, new_str=new_str
)
log_service.append(context=context, uid=log_id, new_str=new_str)
return __builtin__.print(
f"FUNCTION LOG ({context.job.log_id}):",
f"FUNCTION LOG ({job_id}):",
*new_args,
end=end,
sep=sep,
Expand All @@ -997,24 +1024,27 @@ def to_str(arg: Any) -> str:
print = original_print

if code_item.uses_domain:
kwargs["domain"] = LocalDomainClient()
kwargs["domain"] = LocalDomainClient(context=context)

stdout = StringIO()
stderr = StringIO()

# statisfy lint checker
result = None

exec(code_item.byte_code) # nosec
# res = exec(code_item.byte_code, {'print': print}, None) # nosec

_locals = locals()
_globals = {}

user_code_service = context.node.get_service("usercodeservice")
for user_code in user_code_service.stash.get_all(context.credentials).ok():
globals()[user_code.service_func_name] = user_code
globals()["print"] = print
_globals[user_code.service_func_name] = user_code
_globals["print"] = print
exec(code_item.parsed_code, _globals, locals()) # nosec

evil_string = f"{code_item.unique_func_name}(**kwargs)"
result = eval(evil_string, None, _locals) # nosec
result = eval(evil_string, _globals, _locals) # nosec

# reset print
print = original_print
Expand Down
27 changes: 27 additions & 0 deletions packages/syft/src/syft/service/job/job_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ...util.telemetry import instrument
from ..context import AuthedServiceContext
from ..response import SyftError
from ..response import SyftSuccess
from ..service import AbstractService
from ..service import service_method
from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL
Expand Down Expand Up @@ -41,6 +42,32 @@ def get(
res = res.ok()
return res

@service_method(
path="job.get_all",
name="get_all",
)
def get_all(self, context: AuthedServiceContext) -> Union[List[Job], SyftError]:
res = self.stash.get_all(context.credentials)
if res.is_err():
return SyftError(message=res.err())
else:
res = res.ok()
return res

@service_method(
path="job.udpate",
name="update",
roles=DATA_SCIENTIST_ROLE_LEVEL,
)
def update(
self, context: AuthedServiceContext, job: Job
) -> Union[SyftSuccess, SyftError]:
res = self.stash.update(context.credentials, obj=job)
if res.is_err():
return SyftError(message=res.err())
res = res.ok()
return SyftSuccess(message="Great Success!")

@service_method(
path="job.get_subjobs",
name="get_subjobs",
Expand Down
52 changes: 44 additions & 8 deletions packages/syft/src/syft/service/job/job_stash.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# stdlib
from datetime import datetime
from enum import Enum
from typing import Any
from typing import Dict
Expand Down Expand Up @@ -52,9 +53,34 @@ class Job(SyftObject):
status: JobStatus = JobStatus.CREATED
log_id: Optional[UID]
parent_job_id: Optional[UID]
max_checkpoints: Optional[int] = 0
current_checkpoint: Optional[int] = 0
creation_time: Optional[str] = str(datetime.now())

__attr_searchable__ = ["parent_job_id"]
__repr_attrs__ = ["id", "result", "resolved"]
__repr_attrs__ = ["id", "result", "resolved", "progress", "creation_time"]

@property
def progress(self) -> str:
if self.status == JobStatus.PROCESSING:
return_string = self.status
if self.max_checkpoints > 0:
return_string += f": {self.current_checkpoint}/{self.max_checkpoints}"
if self.current_checkpoint == self.max_checkpoints:
return_string += " Almost done..."
elif self.current_checkpoint > 0:
now = datetime.now()
time_passed = now - datetime.fromisoformat(self.creation_time)
time_per_checkpoint = time_passed / self.current_checkpoint
remaining_checkpoints = self.max_checkpoints - self.current_checkpoint

# Probably need to divide by the number of consumers
remaining_time = remaining_checkpoints * time_per_checkpoint
return_string += " Remaining time: " + str(remaining_time)[:-7]
else:
return_string += " Estimating remaining time..."
return return_string
return self.status

def fetch(self) -> None:
api = APIRegistry.api_for(
Expand Down Expand Up @@ -82,6 +108,14 @@ def subjobs(self):
)
return api.services.job.get_subjobs(self.id)

@property
def owner(self):
api = APIRegistry.api_for(
node_uid=self.node_uid,
user_verify_key=self.syft_client_verify_key,
)
return api.services.user.get_current_user(self.id)

def logs(self, _print=True):
api = APIRegistry.api_for(
node_uid=self.node_uid,
Expand All @@ -107,15 +141,17 @@ def _coll_repr_(self) -> Dict[str, Any]:
logs = logs

if self.result is None:
result = ""
pass
else:
result = str(self.result.syft_action_data)
str(self.result.syft_action_data)

return {
"status": self.status,
"logs": logs,
"result": result,
"has_parent": self.has_parent,
"progress": self.progress,
"creation date": self.creation_time[:-7],
# "logs": logs,
# "result": result,
"owner email": self.owner.email,
"parent_id": str(self.parent_job_id) if self.parent_job_id else "-",
"subjobs": len(subjobs),
}

Expand Down Expand Up @@ -185,7 +221,7 @@ def set_result(
item: Job,
add_permissions: Optional[List[ActionObjectPermission]] = None,
) -> Result[Optional[Job], str]:
if item.resolved:
if True: # item.resolved:
valid = self.check_type(item, self.object_type)
if valid.is_err():
return SyftError(message=valid.err())
Expand Down
1 change: 0 additions & 1 deletion packages/syft/src/syft/service/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ def partition_by_node(kwargs: Dict[str, Any]) -> Dict[str, UID]:
uid = v.id
if isinstance(v, Asset):
uid = v.action_id

if not isinstance(uid, UID):
raise Exception(f"Input {k} must have a UID not {type(v)}")

Expand Down
4 changes: 2 additions & 2 deletions packages/syft/src/syft/store/blob_storage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def read(self) -> Union[SyftObject, SyftError]:
else:
return self._read_data()

def _read_data(self, stream=False):
def _read_data(self, stream=False, chunk_size=512):
# relative
from ...client.api import APIRegistry

Expand All @@ -136,7 +136,7 @@ def _read_data(self, stream=False):
response.raise_for_status()
if self.type_ is BlobFileType:
if stream:
return response.iter_lines()
return response.iter_lines(chunk_size=chunk_size)
else:
return response.content
return deserialize(response.content, from_bytes=True)
Expand Down
5 changes: 3 additions & 2 deletions packages/syft/src/syft/store/blob_storage/seaweedfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,12 @@ def write(self, data: BytesIO) -> Union[SyftSuccess, SyftError]:
etags = []

try:
no_lines = 0
for part_no, (byte_chunk, url) in enumerate(
zip(_byte_chunks(data, DEFAULT_CHUNK_SIZE), self.urls),
start=1,
):
no_lines += byte_chunk.count(b"\n")
if api is not None:
blob_url = api.connection.to_blob_route(
url.url_path, host=url.host_or_ip
Expand All @@ -94,8 +96,7 @@ def write(self, data: BytesIO) -> Union[SyftSuccess, SyftError]:
syft_client_verify_key=self.syft_client_verify_key,
)
return mark_write_complete_method(
etags=etags,
uid=self.blob_storage_entry_id,
etags=etags, uid=self.blob_storage_entry_id, no_lines=no_lines
)


Expand Down
Loading