Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
koenvanderveen committed Oct 16, 2023
2 parents 0350660 + 1133f8f commit 395e192
Show file tree
Hide file tree
Showing 14 changed files with 1,221 additions and 132 deletions.
178 changes: 90 additions & 88 deletions notebooks/helm/helm_syft.ipynb

Large diffs are not rendered by default.

965 changes: 965 additions & 0 deletions notebooks/helm/new_policy.ipynb

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions packages/syft/src/syft/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,12 @@ def exchange_route(self, client: Self) -> Union[SyftSuccess, SyftError]:

return result

@property
def jobs(self) -> Optional[APIModule]:
    """Return the API module of the "job" service, if the node exposes one.

    Returns:
        ``self.api.services.job`` when ``self.api.has_service("job")`` is
        truthy, otherwise ``None``.
    """
    job_service_available = self.api.has_service("job")
    return self.api.services.job if job_service_available else None

@property
def users(self) -> Optional[APIModule]:
if self.api.has_service("user"):
Expand Down
1 change: 0 additions & 1 deletion packages/syft/src/syft/node/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,6 @@ def init_queue_manager(self, queue_config: Optional[QueueConfig]):
for _ in range(queue_config.client_config.n_consumers):
if address is None:
raise ValueError("address unknown for consumers")
print("INITIALIZING CONSUMER")
consumer = self.queue_manager.create_consumer(
message_handler, address=address
)
Expand Down
5 changes: 4 additions & 1 deletion packages/syft/src/syft/service/action/action_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,12 @@ def _user_code_execute(
code_item: UserCode,
kwargs: Dict[str, Any],
) -> Result[ActionObjectPointer, Err]:
filtered_kwargs = code_item.input_policy.filter_kwargs(
input_policy = code_item.input_policy
filtered_kwargs = input_policy.filter_kwargs(
kwargs=kwargs, context=context, code_item_id=code_item.id
)
# update input policy to track any input state
code_item.input_policy = input_policy

expected_input_kwargs = set()
for _inp_kwarg in code_item.input_policy.inputs.values():
Expand Down
9 changes: 9 additions & 0 deletions packages/syft/src/syft/service/blob_storage/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def mark_write_complete(
context: AuthedServiceContext,
uid: UID,
etags: List,
no_lines: Optional[int] = 0,
) -> Union[SyftError, SyftSuccess]:
result = self.stash.get_by_uid(
credentials=context.credentials,
Expand All @@ -162,6 +163,14 @@ def mark_write_complete(
if obj is None:
return SyftError(message=f"No blob storage entry exists for uid: {uid}")

obj.no_lines = no_lines
result = self.stash.update(
credentials=context.credentials,
obj=obj,
)
if result.is_err():
return SyftError(message=f"{result.err()}")

with context.node.blob_storage_client.connect() as conn:
result = conn.complete_multipart_upload(obj, etags)

Expand Down
54 changes: 36 additions & 18 deletions packages/syft/src/syft/service/code/user_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,15 +893,32 @@ def execute_byte_code(
original_print = __builtin__.print

class LocalDomainClient:
def __init__(self):
pass
def __init__(self, context):
self.context = context

def init_progress(self, n_iters):
    """Reset the running job's progress counter and record its iteration total.

    Does nothing when the execution context carries no job. Otherwise sets
    ``current_iter`` to 0 and ``n_iters`` to the given total on the context's
    job and hands the job to the node's "jobservice" via ``update``.

    Args:
        n_iters: value stored on the job as ``n_iters``.
    """
    ctx = self.context
    if ctx.job is None:
        return

    # Resolve the service before touching the job, so a failed lookup leaves
    # the job's counters unchanged.
    progress_service = ctx.node.get_service("jobservice")
    running_job = ctx.job
    running_job.current_iter = 0
    running_job.n_iters = n_iters
    progress_service.update(ctx, running_job)

def update_progress(self, n=1):
    """Advance the running job's iteration counter and store the change.

    Does nothing when the execution context carries no job. Otherwise adds
    ``n`` to the job's ``current_iter`` and hands the job to the node's
    "jobservice" via ``update``.

    Args:
        n: amount added to ``current_iter``; defaults to 1.
    """
    ctx = self.context
    if ctx.job is None:
        return

    # Resolve the service before touching the job, so a failed lookup leaves
    # the job's counter unchanged.
    progress_service = ctx.node.get_service("jobservice")
    running_job = ctx.job
    running_job.current_iter += n
    progress_service.update(ctx, running_job)

def launch_job(self, func: UserCode, **kwargs):
# relative
from ... import UID

# get reference to node (TODO)
node = context.node
node = self.context.node
action_service = node.get_service("actionservice")
user_service = node.get_service("userservice")
user_code_service = node.get_service("usercodeservice")
Expand All @@ -913,7 +930,7 @@ def launch_job(self, func: UserCode, **kwargs):
kw2id = {}
for k, v in kwargs.items():
value = ActionObject.from_obj(v)
ptr = action_service.set(context, value)
ptr = action_service.set(self.context, value)
ptr = ptr.ok()
kw2id[k] = ptr.id

Expand All @@ -931,7 +948,7 @@ def launch_job(self, func: UserCode, **kwargs):

# TODO: throw exception for enclaves
request = user_code_service._request_code_execution_inner(
context, new_user_code
self.context, new_user_code
).ok()
admin_context = AuthedServiceContext(
node=node,
Expand All @@ -954,24 +971,24 @@ def launch_job(self, func: UserCode, **kwargs):

original_print(f"LAUNCHING JOB {func.service_func_name}")
job = node.add_api_call_to_queue(
api_call, parent_job_id=context.job_id
api_call, parent_job_id=self.context.job_id
)

# set api in global scope to enable using .get() and .wait()
user_signing_key = [
x.signing_key
for x in user_service.stash.partition.data.values()
if x.verify_key == context.credentials
if x.verify_key == self.context.credentials
][0]
user_api = node.get_api(context.credentials)
user_api = node.get_api(self.context.credentials)
user_api.signing_key = user_signing_key
# We hardcode a python connection here since we have access to the node
# TODO: this is not secure
user_api.connection = PythonConnection(node=node)

APIRegistry.set_api_for(
node_uid=node.id,
user_verify_key=context.credentials,
user_verify_key=self.context.credentials,
api=user_api,
)

Expand All @@ -981,6 +998,8 @@ def launch_job(self, func: UserCode, **kwargs):
raise ValueError(f"error while launching job:\n{e}")

if context.job is not None:
job_id = context.job_id
log_id = context.job.log_id

def print(*args, sep=" ", end="\n"):
def to_str(arg: Any) -> str:
Expand All @@ -997,11 +1016,9 @@ def to_str(arg: Any) -> str:
new_args = [to_str(arg) for arg in args]
new_str = sep.join(new_args) + end
log_service = context.node.get_service("LogService")
log_service.append(
context=context, uid=context.job.log_id, new_str=new_str
)
log_service.append(context=context, uid=log_id, new_str=new_str)
return __builtin__.print(
f"FUNCTION LOG ({context.job.log_id}):",
f"FUNCTION LOG ({job_id}):",
*new_args,
end=end,
sep=sep,
Expand All @@ -1012,24 +1029,25 @@ def to_str(arg: Any) -> str:
print = original_print

if code_item.uses_domain:
kwargs["domain"] = LocalDomainClient()
kwargs["domain"] = LocalDomainClient(context=context)

stdout = StringIO()
stderr = StringIO()

# satisfy lint checker
result = None

exec(code_item.byte_code) # nosec
_locals = locals()
_globals = {}

user_code_service = context.node.get_service("usercodeservice")
for user_code in user_code_service.stash.get_all(context.credentials).ok():
globals()[user_code.service_func_name] = user_code
globals()["print"] = print
_globals[user_code.service_func_name] = user_code
_globals["print"] = print
exec(code_item.parsed_code, _globals, locals()) # nosec

evil_string = f"{code_item.unique_func_name}(**kwargs)"
result = eval(evil_string, None, _locals) # nosec
result = eval(evil_string, _globals, _locals) # nosec

# reset print
print = original_print
Expand Down
27 changes: 27 additions & 0 deletions packages/syft/src/syft/service/job/job_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ...util.telemetry import instrument
from ..context import AuthedServiceContext
from ..response import SyftError
from ..response import SyftSuccess
from ..service import AbstractService
from ..service import service_method
from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL
Expand Down Expand Up @@ -41,6 +42,32 @@ def get(
res = res.ok()
return res

@service_method(
    path="job.get_all",
    name="get_all",
)
def get_all(self, context: AuthedServiceContext) -> Union[List[Job], SyftError]:
    """Return every job the stash yields for the caller's credentials.

    Args:
        context: authenticated service context; its ``credentials`` are
            passed to ``self.stash.get_all``.

    Returns:
        The value unwrapped from the stash result, or a ``SyftError``
        carrying the stash's error when the lookup fails.
    """
    lookup = self.stash.get_all(context.credentials)
    if lookup.is_err():
        return SyftError(message=lookup.err())
    return lookup.ok()

@service_method(
    path="job.update",
    name="update",
    roles=DATA_SCIENTIST_ROLE_LEVEL,
)
def update(
    self, context: AuthedServiceContext, job: Job
) -> Union[SyftSuccess, SyftError]:
    """Persist the given job through the stash.

    Args:
        context: authenticated service context; its ``credentials`` are
            passed to ``self.stash.update``.
        job: the job object to store.

    Returns:
        ``SyftSuccess`` when the stash accepts the update, otherwise a
        ``SyftError`` carrying the stash's error.
    """
    res = self.stash.update(context.credentials, obj=job)
    if res.is_err():
        return SyftError(message=res.err())
    # The object returned by the stash is not part of this method's response,
    # so it is deliberately not unwrapped.
    return SyftSuccess(message="Great Success!")

@service_method(
path="job.get_subjobs",
name="get_subjobs",
Expand Down
61 changes: 48 additions & 13 deletions packages/syft/src/syft/service/job/job_stash.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# stdlib
from datetime import datetime
from enum import Enum
from typing import Any
from typing import Dict
Expand Down Expand Up @@ -52,9 +53,35 @@ class Job(SyftObject):
status: JobStatus = JobStatus.CREATED
log_id: Optional[UID]
parent_job_id: Optional[UID]
n_iters: Optional[int] = 0
current_iter: Optional[int] = 0
creation_time: Optional[str] = str(datetime.now())

__attr_searchable__ = ["parent_job_id"]
__repr_attrs__ = ["id", "result", "resolved"]
__repr_attrs__ = ["id", "result", "resolved", "progress", "creation_time"]

@property
def progress(self) -> str:
    """Describe the job's status, with iteration progress while processing.

    For any status other than ``JobStatus.PROCESSING`` the status itself is
    returned. While processing, and once ``n_iters`` is positive, the status
    is extended with ``current_iter/n_iters`` followed by one of:

    * ``" Almost done..."`` when every iteration has been reported,
    * ``" Remaining time: H:MM:SS"`` when at least one iteration has been
      reported, or
    * ``" Estimating remaining time..."`` when none has been reported yet.

    The estimate is linear: time elapsed since ``creation_time`` divided by
    the iterations completed, multiplied by the iterations still to go.

    NOTE(review): the field default for ``creation_time`` is written as
    ``str(datetime.now())`` in the class body, which Python evaluates once
    when the class is defined -- verify that each job really gets its own
    timestamp, otherwise this estimate is skewed.
    """
    if self.status != JobStatus.PROCESSING:
        return self.status

    # Both counters are declared Optional; treat a missing value as zero
    # rather than failing on a comparison with None.
    total = self.n_iters or 0
    done = self.current_iter or 0

    return_string = self.status
    if total <= 0:
        return return_string

    return_string += f": {done}/{total}"
    if done >= total:
        # ">=" so an overshooting counter never yields a negative estimate.
        return_string += " Almost done..."
    elif done > 0:
        time_passed = datetime.now() - datetime.fromisoformat(self.creation_time)
        time_per_checkpoint = time_passed / done
        remaining_checkpoints = total - done

        # Probably need to divide by the number of consumers
        remaining_time = remaining_checkpoints * time_per_checkpoint
        # Drop the fractional seconds. str(timedelta) omits the ".ffffff"
        # part entirely when microseconds are zero, so cut at the dot
        # instead of slicing a fixed number of characters.
        remaining_time = str(remaining_time).partition(".")[0]
        return_string += f" Remaining time: {remaining_time}"
    else:
        return_string += " Estimating remaining time..."
    return return_string

def fetch(self) -> None:
api = APIRegistry.api_for(
Expand Down Expand Up @@ -82,6 +109,14 @@ def subjobs(self):
)
return api.services.job.get_subjobs(self.id)

@property
def owner(self):
    """Return what the user service's ``get_current_user`` yields for this job.

    Looks up the API registered for this object's ``node_uid`` and
    ``syft_client_verify_key`` and calls ``services.user.get_current_user``
    on it with this job's ``id``.

    NOTE(review): the job id is what gets passed to ``get_current_user`` --
    confirm that is the argument the service expects.
    """
    return APIRegistry.api_for(
        node_uid=self.node_uid,
        user_verify_key=self.syft_client_verify_key,
    ).services.user.get_current_user(self.id)

def logs(self, _print=True):
api = APIRegistry.api_for(
node_uid=self.node_uid,
Expand All @@ -107,15 +142,17 @@ def _coll_repr_(self) -> Dict[str, Any]:
logs = logs

if self.result is None:
result = ""
pass
else:
result = str(self.result.syft_action_data)
str(self.result.syft_action_data)

return {
"status": self.status,
"logs": logs,
"result": result,
"has_parent": self.has_parent,
"progress": self.progress,
"creation date": self.creation_time[:-7],
# "logs": logs,
# "result": result,
"owner email": self.owner.email,
"parent_id": str(self.parent_job_id) if self.parent_job_id else "-",
"subjobs": len(subjobs),
}

Expand Down Expand Up @@ -185,12 +222,10 @@ def set_result(
item: Job,
add_permissions: Optional[List[ActionObjectPermission]] = None,
) -> Result[Optional[Job], str]:
if item.resolved:
valid = self.check_type(item, self.object_type)
if valid.is_err():
return SyftError(message=valid.err())
return super().update(credentials, item, add_permissions)
return None
valid = self.check_type(item, self.object_type)
if valid.is_err():
return SyftError(message=valid.err())
return super().update(credentials, item, add_permissions)

def set_placeholder(
self,
Expand Down
1 change: 0 additions & 1 deletion packages/syft/src/syft/service/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ def partition_by_node(kwargs: Dict[str, Any]) -> Dict[str, UID]:
uid = v.id
if isinstance(v, Asset):
uid = v.action_id

if not isinstance(uid, UID):
raise Exception(f"Input {k} must have a UID not {type(v)}")

Expand Down
1 change: 0 additions & 1 deletion packages/syft/src/syft/service/queue/zmq_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ def post_init(self):
self.thread = None

def _run(self):
print("ABCDEF", flush=True)
liveness = HEARTBEAT_LIVENESS
interval = INTERVAL_INIT
heartbeat_at = time.time() + HEARTBEAT_INTERVAL
Expand Down
4 changes: 2 additions & 2 deletions packages/syft/src/syft/store/blob_storage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def read(self) -> Union[SyftObject, SyftError]:
else:
return self._read_data()

def _read_data(self, stream=False):
def _read_data(self, stream=False, chunk_size=512):
# relative
from ...client.api import APIRegistry

Expand All @@ -136,7 +136,7 @@ def _read_data(self, stream=False):
response.raise_for_status()
if self.type_ is BlobFileType:
if stream:
return response.iter_lines()
return response.iter_lines(chunk_size=chunk_size)
else:
return response.content
return deserialize(response.content, from_bytes=True)
Expand Down
5 changes: 3 additions & 2 deletions packages/syft/src/syft/store/blob_storage/seaweedfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,12 @@ def write(self, data: BytesIO) -> Union[SyftSuccess, SyftError]:
etags = []

try:
no_lines = 0
for part_no, (byte_chunk, url) in enumerate(
zip(_byte_chunks(data, DEFAULT_CHUNK_SIZE), self.urls),
start=1,
):
no_lines += byte_chunk.count(b"\n")
if api is not None:
blob_url = api.connection.to_blob_route(
url.url_path, host=url.host_or_ip
Expand All @@ -94,8 +96,7 @@ def write(self, data: BytesIO) -> Union[SyftSuccess, SyftError]:
syft_client_verify_key=self.syft_client_verify_key,
)
return mark_write_complete_method(
etags=etags,
uid=self.blob_storage_entry_id,
etags=etags, uid=self.blob_storage_entry_id, no_lines=no_lines
)


Expand Down
Loading

0 comments on commit 395e192

Please sign in to comment.