diff --git a/notebooks/helm/helm_syft.ipynb b/notebooks/helm/helm_syft.ipynb index f8326ed003b..8e46b05005c 100644 --- a/notebooks/helm/helm_syft.ipynb +++ b/notebooks/helm/helm_syft.ipynb @@ -32,14 +32,6 @@ "name": "stdout", "output_type": "stream", "text": [ - "INITIALIZING CONSUMER\n", - "ABCDEF\n", - "INITIALIZING CONSUMER\n", - "ABCDEF\n", - "INITIALIZING CONSUMER\n", - "ABCDEF\n", - "INITIALIZING CONSUMER\n", - "ABCDEF\n", "Logged into as \n" ] }, @@ -209,7 +201,7 @@ ], "source": [ "@sy.syft_function()\n", - "def compute_document_data_overlap(scenario_file, input_files, n):\n", + "def compute_document_data_overlap(domain, scenario_file, input_files, n):\n", " print(\"starting overlap computation\")\n", "\n", " from nltk import ngrams\n", @@ -272,6 +264,8 @@ " stats_key_to_reference_ids = defaultdict(set)\n", " print(\"computing overlap\")\n", " \n", + " domain.init_progress(len(input_files))\n", + " \n", " for input_file in input_files:\n", " for line in input_file.iter_lines():\n", " document = json.loads(line)[\"text\"]\n", @@ -285,6 +279,7 @@ " stats_key_to_input_ids[stats_key].add(id)\n", " elif part == \"references\":\n", " stats_key_to_reference_ids[stats_key].add(id)\n", + " domain.update_progress(1)\n", " print(\"done\")\n", " \n", " return stats_key_to_input_ids, stats_key_to_reference_ids, stats_key_counts" @@ -364,10 +359,10 @@ { "data": { "text/html": [ - "
SyftSuccess: Request 8a84551794dc49cda40f4d683fb19bdf changes applied

" + "
SyftSuccess: Request b9ab2ed69652452d82067aef2deea9a0 changes applied

" ], "text/plain": [ - "SyftSuccess: Request 8a84551794dc49cda40f4d683fb19bdf changes applied" + "SyftSuccess: Request b9ab2ed69652452d82067aef2deea9a0 changes applied" ] }, "execution_count": 13, @@ -406,13 +401,14 @@ "text/markdown": [ "```python\n", "class Job:\n", - " id: UID = 585159e4b6dd4599a3f0db6755e676f3\n", - " status: created\n", + " id: UID = 1101b23efa144e79b7623569dae4ffbe\n", + " status: completed\n", " has_parent: False\n", - " result: None\n", + " result: ActionDataEmpty UID: 249d1ff6f6674a8c979e37d25cc797ef \n", " logs:\n", "\n", "0 \n", + "JOB COMPLETED\n", " \n", "```" ], @@ -435,25 +431,26 @@ "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "LAUNCHING JOB compute_document_data_overlap\n", - "LAUNCHING JOB compute_document_data_overlap\n" - ] + "data": { + "text/plain": [ + "Pointer:\n", + "None" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" }, { "name": "stderr", "output_type": "stream", "text": [ - "FUNCTION LOG (80211968cab84fc0bbc0b6d7d986fa2f): starting overlap computation\n", - "FUNCTION LOG (ba3acba2f0b74644a1c61565187d21b4): starting overlap computation\n", - "FUNCTION LOG (ba3acba2f0b74644a1c61565187d21b4): preparing scenarios and creating indexes\n", - "FUNCTION LOG (ba3acba2f0b74644a1c61565187d21b4): computing overlap\n", - "FUNCTION LOG (ba3acba2f0b74644a1c61565187d21b4): preparing scenarios and creating indexes\n", - "FUNCTION LOG (ba3acba2f0b74644a1c61565187d21b4): computing overlap\n", - "FUNCTION LOG (ba3acba2f0b74644a1c61565187d21b4): done\n", - "FUNCTION LOG (ba3acba2f0b74644a1c61565187d21b4): done\n" + "FUNCTION LOG (96c55a32fccb4276a26f9b161282625d): preparing scenarios and creating indexes\n", + "FUNCTION LOG (4402cf5b2ceb42ceba08eafb85c25098): preparing scenarios and creating indexes\n", + "FUNCTION LOG (96c55a32fccb4276a26f9b161282625d): computing overlap\n", + "FUNCTION LOG (4402cf5b2ceb42ceba08eafb85c25098): computing overlap\n", + "FUNCTION LOG (96c55a32fccb4276a26f9b161282625d): done\n", + "FUNCTION LOG (4402cf5b2ceb42ceba08eafb85c25098): done\n" ] } ], @@ -672,7 +669,7 @@ " flex-grow: 0;\n", " }\n", "\n", - " .grid-tablea08d89bda1db4d7da3f282f3473ccb23 {\n", + " .grid-table81125880bac241d3aeb59e4c6555a78d {\n", " display:grid;\n", " grid-template-columns: 1fr repeat(24, 1fr);\n", " grid-template-rows: repeat(2, 1fr);\n", @@ -844,25 +841,25 @@ "
"[... auto-generated Syft table-widget HTML (pagination and search controls) elided; the only substantive change in this span is the regenerated grid-table id ...]\n",
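The notebook above exercises the progress API this PR introduces: the user code calls `domain.init_progress(len(input_files))` once, then `domain.update_progress(1)` per file, and the `Job.progress` property added in `job_stash.py` later in this diff turns those counters into a remaining-time estimate. A minimal sketch of that estimate, simplified from the diff into a standalone function rather than the `Job` property:

```python
from datetime import datetime, timedelta

def estimate_remaining(creation_time: str, current_iter: int, n_iters: int) -> timedelta:
    """Average time per finished iteration, times iterations left (requires current_iter > 0)."""
    elapsed = datetime.now() - datetime.fromisoformat(creation_time)
    per_iter = elapsed / current_iter            # mean duration of one iteration so far
    return per_iter * (n_iters - current_iter)   # naive: assumes one consumer and a constant rate
```

As the in-diff comment ("Probably need to divide by the number of consumers") notes, a multi-consumer queue finishes faster than this estimate; dividing by the consumer count would be a first correction.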
\n", @@ -1100,9 +1105,6 @@ "starting overlap computation\n", "preparing scenarios and creating indexes\n", "computing overlap\n", - "preparing scenarios and creating indexes\n", - "done\n", - "computing overlap\n", "done\n", "\n" ] @@ -1168,7 +1170,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -1186,7 +1188,7 @@ " 'anatomy_test_5': 135}))" ] }, - "execution_count": 21, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -1204,7 +1206,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -1255,7 +1257,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 24, "metadata": {}, "outputs": [ { diff --git a/notebooks/helm/new_policy.ipynb b/notebooks/helm/new_policy.ipynb new file mode 100644 index 00000000000..1c3a50ff9c0 --- /dev/null +++ b/notebooks/helm/new_policy.ipynb @@ -0,0 +1,965 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "kj/filesystem-disk-unix.c++:1703: warning: PWD environment variable doesn't match current directory; pwd = /home/teo/OpenMined/PySyft\n" + ] + } + ], + "source": [ + "import syft as sy" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Logged into as \n" + ] + }, + { + "data": { + "text/html": [ + "
SyftWarning: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`.

" + ], + "text/plain": [ + "SyftWarning: You are using a default password. Please change the password using `[your_client].me.set_password([new_password])`." + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "node = sy.orchestra.launch(name=\"test-domain-helm2\", dev_mode=True, reset=True, n_consumers=30,\n", + " create_producer=True)\n", + "client = node.login(email=\"info@openmined.org\", password=\"changethis\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Question 1: What type of container do we want for data already on the server?\n", + "\n", + "obj_1 = sy.ActionObject.from_obj(1)\n", + "ptr_1 = obj_1.send(client)\n", + "\n", + "obj_2 = sy.ActionObject.from_obj(2)\n", + "ptr_2 = obj_2.send(client)\n", + "\n", + "obj_3 = sy.ActionObject.from_obj(3)\n", + "ptr_3 = obj_3.send(client)\n", + "\n", + "# Option 1: ActionObjects inside ActionObjects\n", + "# \n", + "# Pros: very versatile, could work with data from other domains out of the box\n", + "# Cons: might not feel intuitive to the user, will need to change the way we work with\n", + "# ActionObjects in a lot of different places in the codebase\n", + "list = sy.ActionObject.from_obj([ptr_1, ptr_2, ptr_3])\n", + "list_ptr = list.send(client)\n", + "\n", + "# Option 2: Create new ActionObjects from the same data\n", + "# Will require us to do some value based verification on different objects\n", + "# \n", + "# Pros: Easier abstraction for the user\n", + "# Cons: Value based verification sounds like an attack vector\n", + "# as it can provide a free Oracle to an attacker\n", + "list = sy.ActionObject.from_list([ptr_1, ptr_2, ptr_3]) # on the server will do ActionObject.from_obj([1,2,3])\n", + "list_ptr = list.send(client)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
SyftSuccess: Syft function 'func' successfully created. To add a code request, please create a project using `project = syft.Project(...)`, then use command `project.create_code_request`.

" + ], + "text/plain": [ + "SyftSuccess: Syft function 'func' successfully created. To add a code request, please create a project using `project = syft.Project(...)`, then use command `project.create_code_request`." + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from syft.service.policy.policy import OutputPolicyExecuteCount\n", + "\n", + "# Question 2: What should the UX be for ExecuteOncePerCombination?\n", + "# \n", + "# Right now I have worked on using the first option from the previous question\n", + "# and using on the fly created lists. We can break this question into more specific ones:\n", + "#\n", + "# Sub-Question 1: What should we pass for each argument? Should the list be already on the server?\n", + "# Or can it be defined by the data scientist? \n", + "# Could it be made of data outside the domain?\n", + "#\n", + "# Sub-Question 2: Will anything change if instead of data we talk about files?\n", + "# The final use case actually will iterate for SyftFiles, so can this affect the UX?\n", + "#\n", + "\n", + "@sy.syft_function(input_policy=sy.ExecuteOncePerCombination(\n", + " x=[ptr_1, ptr_2, ptr_3],\n", + " y=[ptr_1, ptr_2, ptr_3],\n", + " z=[ptr_1, ptr_2, ptr_3],\n", + " ),\n", + " output_policy=OutputPolicyExecuteCount(limit=27))\n", + "def func(x, y, z):\n", + " return x, y, z" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "request = client.code.submit(func)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
SyftSuccess: Syft function 'main_func' successfully created. To add a code request, please create a project using `project = syft.Project(...)`, then use command `project.create_code_request`.

" + ], + "text/plain": [ + "SyftSuccess: Syft function 'main_func' successfully created. To add a code request, please create a project using `project = syft.Project(...)`, then use command `project.create_code_request`." + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "@sy.syft_function_single_use(list=list_ptr)\n", + "def main_func(domain, list):\n", + " jobs = []\n", + " print(\"start\")\n", + " # domain.init_progress(27)\n", + " for x in list:\n", + " for y in list:\n", + " for z in list:\n", + " print(x,y,z)\n", + " # domain.progress()\n", + " batch_job = domain.launch_job(func, x=x, y=y, z=z)\n", + " jobs.append(batch_job)\n", + " \n", + " print(\"done\")\n", + " \n", + " return None" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Request approved for domain test-domain-helm2\n" + ] + }, + { + "data": { + "text/html": [ + "
SyftSuccess: Request 51fa624adc3d47e7a8dc97886df8dfdc changes applied
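`main_func` fans out 27 nested `domain.launch_job(func, x=x, y=y, z=z)` calls, one per combination, which is what the `ExecuteOncePerCombination` input policy is meant to gate. The policy's implementation is not part of this diff, so the following is only an illustrative sketch of the intended semantics, with hypothetical names: each declared combination may be consumed exactly once.

```python
from itertools import product

class ExecuteOncePerCombinationSketch:
    """Hypothetical stand-in: admit each declared (x, y, z) combination exactly once."""

    def __init__(self, **arg_lists):
        self.keys = list(arg_lists)
        self.remaining = set(product(*arg_lists.values()))  # 3 * 3 * 3 = 27 combinations here

    def is_valid(self, **kwargs) -> bool:
        combo = tuple(kwargs[k] for k in self.keys)
        if combo not in self.remaining:
            return False              # undeclared, or already used
        self.remaining.remove(combo)  # consume it: a replay is rejected
        return True

policy = ExecuteOncePerCombinationSketch(x=[1, 2, 3], y=[1, 2, 3], z=[1, 2, 3])
assert policy.is_valid(x=1, y=2, z=3)
assert not policy.is_valid(x=1, y=2, z=3)  # second identical call is blocked
```

Paired with `OutputPolicyExecuteCount(limit=27)` as in the cell above, this would bound the function to one run per combination and 27 runs overall.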

" + ], + "text/plain": [ + "SyftSuccess: Request 51fa624adc3d47e7a8dc97886df8dfdc changes applied" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "client.code.request_code_execution(main_func)\n", + "client.requests[-1].approve()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "job = client.code.main_func(list=list_ptr, blocking=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/markdown": [ + "```python\n", + "class Job:\n", + " id: UID = 02936696f4a64aac98af478b04decb3d\n", + " status: JobStatus.CREATED\n", + " has_parent: False\n", + " result: None\n", + " logs:\n", + "\n", + "0 \n", + " \n", + "```" + ], + "text/plain": [ + "syft.service.job.job_stash.Job" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "job" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "jobs = client.jobs" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "\n", + "
"[... auto-generated 'Job List' table-widget HTML elided ...]\n",
\n", + " \n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jobs" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "# jobs[2].parent_job_id" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "job_tree = {}\n", + "for job in jobs:\n", + " if job.parent_job_id in job_tree:\n", + " job_tree[job.parent_job_id].append(job)\n", + " else:\n", + " job_tree[job.parent_job_id] = [job]" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "├─ 02936696f4a64aac98af478b04decb3d\n" + ] + } + ], + "source": [ + "def recursive_print(parent_job, tab_space = 0):\n", + " lines = \"─\"\n", + " if parent_job.id in job_tree:\n", + " for job in job_tree[parent_job.id]:\n", + " print(f\"├─{lines * 2}\", job.id)\n", + " recursive_print(job, tab_space=tab_space+2)\n", + "\n", + "for job in jobs:\n", + " if not job.has_parent:\n", + " print(\"├─\", job.id)\n", + " recursive_print(job, tab_space=2)\n", + " " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "syft_3.11", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py index f6fe457d162..1924cf96dbb 100644 --- a/packages/syft/src/syft/client/client.py +++ b/packages/syft/src/syft/client/client.py @@ -560,6 +560,12 @@ def exchange_route(self, client: Self) -> Union[SyftSuccess, SyftError]: return result + @property + def jobs(self) -> Optional[APIModule]: + if self.api.has_service("job"): + return self.api.services.job + return None + @property def users(self) -> Optional[APIModule]: if self.api.has_service("user"): diff --git a/packages/syft/src/syft/node/node.py b/packages/syft/src/syft/node/node.py index 376a576215d..694f0ef0c1d 100644 --- a/packages/syft/src/syft/node/node.py +++ b/packages/syft/src/syft/node/node.py @@ -423,7 +423,6 @@ def init_queue_manager(self, queue_config: Optional[QueueConfig]): for _ in range(queue_config.client_config.n_consumers): if address is None: raise ValueError("address unknown for consumers") - print("INITIALIZING CONSUMER") consumer = self.queue_manager.create_consumer( message_handler, address=address ) diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index cee5a838fab..6917deb2b25 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -181,9 +181,12 @@ def _user_code_execute( code_item: UserCode, kwargs: Dict[str, Any], ) -> Result[ActionObjectPointer, Err]: - filtered_kwargs = code_item.input_policy.filter_kwargs( + input_policy = code_item.input_policy + filtered_kwargs = input_policy.filter_kwargs( kwargs=kwargs, context=context, code_item_id=code_item.id ) + # update input policy to track any input state + code_item.input_policy = input_policy expected_input_kwargs = set() for _inp_kwarg in code_item.input_policy.inputs.values(): diff --git 
a/packages/syft/src/syft/service/blob_storage/service.py b/packages/syft/src/syft/service/blob_storage/service.py index d7d6e8e7830..01cde1982c1 100644 --- a/packages/syft/src/syft/service/blob_storage/service.py +++ b/packages/syft/src/syft/service/blob_storage/service.py @@ -149,6 +149,7 @@ def mark_write_complete( context: AuthedServiceContext, uid: UID, etags: List, + no_lines: Optional[int] = 0, ) -> Union[SyftError, SyftSuccess]: result = self.stash.get_by_uid( credentials=context.credentials, @@ -162,6 +163,14 @@ def mark_write_complete( if obj is None: return SyftError(message=f"No blob storage entry exists for uid: {uid}") + obj.no_lines = no_lines + result = self.stash.update( + credentials=context.credentials, + obj=obj, + ) + if result.is_err(): + return SyftError(message=f"{result.err()}") + with context.node.blob_storage_client.connect() as conn: result = conn.complete_multipart_upload(obj, etags) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index b535dd7c1a3..aed08cbf9e6 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -893,15 +893,32 @@ def execute_byte_code( original_print = __builtin__.print class LocalDomainClient: - def __init__(self): - pass + def __init__(self, context): + self.context = context + + def init_progress(self, n_iters): + if self.context.job is not None: + node = self.context.node + job_service = node.get_service("jobservice") + job = self.context.job + job.current_iter = 0 + job.n_iters = n_iters + job_service.update(self.context, job) + + def update_progress(self, n=1): + if self.context.job is not None: + node = self.context.node + job_service = node.get_service("jobservice") + job = self.context.job + job.current_iter += n + job_service.update(self.context, job) def launch_job(self, func: UserCode, **kwargs): # relative from ... 
import UID # get reference to node (TODO) - node = context.node + node = self.context.node action_service = node.get_service("actionservice") user_service = node.get_service("userservice") user_code_service = node.get_service("usercodeservice") @@ -913,7 +930,7 @@ def launch_job(self, func: UserCode, **kwargs): kw2id = {} for k, v in kwargs.items(): value = ActionObject.from_obj(v) - ptr = action_service.set(context, value) + ptr = action_service.set(self.context, value) ptr = ptr.ok() kw2id[k] = ptr.id @@ -931,7 +948,7 @@ def launch_job(self, func: UserCode, **kwargs): # TODO: throw exception for enclaves request = user_code_service._request_code_execution_inner( - context, new_user_code + self.context, new_user_code ).ok() admin_context = AuthedServiceContext( node=node, @@ -954,16 +971,16 @@ def launch_job(self, func: UserCode, **kwargs): original_print(f"LAUNCHING JOB {func.service_func_name}") job = node.add_api_call_to_queue( - api_call, parent_job_id=context.job_id + api_call, parent_job_id=self.context.job_id ) # set api in global scope to enable using .get(), .wait()) user_signing_key = [ x.signing_key for x in user_service.stash.partition.data.values() - if x.verify_key == context.credentials + if x.verify_key == self.context.credentials ][0] - user_api = node.get_api(context.credentials) + user_api = node.get_api(self.context.credentials) user_api.signing_key = user_signing_key # We hardcode a python connection here since we have access to the node # TODO: this is not secure @@ -971,7 +988,7 @@ def launch_job(self, func: UserCode, **kwargs): APIRegistry.set_api_for( node_uid=node.id, - user_verify_key=context.credentials, + user_verify_key=self.context.credentials, api=user_api, ) @@ -981,6 +998,8 @@ def launch_job(self, func: UserCode, **kwargs): raise ValueError(f"error while launching job:\n{e}") if context.job is not None: + job_id = context.job_id + log_id = context.job.log_id def print(*args, sep=" ", end="\n"): def to_str(arg: Any) -> str: @@ -997,11 +1016,9 @@ def to_str(arg: Any) -> str: new_args = [to_str(arg) for arg in args] new_str = sep.join(new_args) + end log_service = context.node.get_service("LogService") - log_service.append( - context=context, uid=context.job.log_id, new_str=new_str - ) + log_service.append(context=context, uid=log_id, new_str=new_str) return __builtin__.print( - f"FUNCTION LOG ({context.job.log_id}):", + f"FUNCTION LOG ({job_id}):", *new_args, end=end, sep=sep, @@ -1012,7 +1029,7 @@ def to_str(arg: Any) -> str: print = original_print if code_item.uses_domain: - kwargs["domain"] = LocalDomainClient() + kwargs["domain"] = LocalDomainClient(context=context) stdout = StringIO() stderr = StringIO() @@ -1020,16 +1037,17 @@ def to_str(arg: Any) -> str: # statisfy lint checker result = None - exec(code_item.byte_code) # nosec _locals = locals() + _globals = {} user_code_service = context.node.get_service("usercodeservice") for user_code in user_code_service.stash.get_all(context.credentials).ok(): - globals()[user_code.service_func_name] = user_code - globals()["print"] = print + _globals[user_code.service_func_name] = user_code + _globals["print"] = print + exec(code_item.parsed_code, _globals, locals()) # nosec evil_string = f"{code_item.unique_func_name}(**kwargs)" - result = eval(evil_string, None, _locals) # nosec + result = eval(evil_string, _globals, _locals) # nosec # reset print print = original_print diff --git a/packages/syft/src/syft/service/job/job_service.py b/packages/syft/src/syft/service/job/job_service.py index 
897ca897293..5f34722e96c 100644 --- a/packages/syft/src/syft/service/job/job_service.py +++ b/packages/syft/src/syft/service/job/job_service.py @@ -9,6 +9,7 @@ from ...util.telemetry import instrument from ..context import AuthedServiceContext from ..response import SyftError +from ..response import SyftSuccess from ..service import AbstractService from ..service import service_method from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL @@ -41,6 +42,32 @@ def get( res = res.ok() return res + @service_method( + path="job.get_all", + name="get_all", + ) + def get_all(self, context: AuthedServiceContext) -> Union[List[Job], SyftError]: + res = self.stash.get_all(context.credentials) + if res.is_err(): + return SyftError(message=res.err()) + else: + res = res.ok() + return res + + @service_method( + path="job.update", + name="update", + roles=DATA_SCIENTIST_ROLE_LEVEL, + ) + def update( + self, context: AuthedServiceContext, job: Job + ) -> Union[SyftSuccess, SyftError]: + res = self.stash.update(context.credentials, obj=job) + if res.is_err(): + return SyftError(message=res.err()) + res = res.ok() + return SyftSuccess(message="Great Success!") + @service_method( path="job.get_subjobs", name="get_subjobs", diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index 443bc0779fd..607e7964b0e 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -1,4 +1,5 @@ # stdlib +from datetime import datetime from enum import Enum from typing import Any from typing import Dict @@ -52,9 +53,35 @@ class Job(SyftObject): status: JobStatus = JobStatus.CREATED log_id: Optional[UID] parent_job_id: Optional[UID] + n_iters: Optional[int] = 0 + current_iter: Optional[int] = 0 + creation_time: Optional[str] = str(datetime.now()) __attr_searchable__ = ["parent_job_id"] - __repr_attrs__ = ["id", "result", "resolved"] + __repr_attrs__ = ["id", "result", "resolved", "progress", "creation_time"] + + @property + def progress(self) -> str: + if self.status == JobStatus.PROCESSING: + return_string = self.status + if self.n_iters > 0: + return_string += f": {self.current_iter}/{self.n_iters}" + if self.current_iter == self.n_iters: + return_string += " Almost done..." + elif self.current_iter > 0: + now = datetime.now() + time_passed = now - datetime.fromisoformat(self.creation_time) + time_per_checkpoint = time_passed / self.current_iter + remaining_checkpoints = self.n_iters - self.current_iter + + # Probably need to divide by the number of consumers + remaining_time = remaining_checkpoints * time_per_checkpoint + remaining_time = str(remaining_time)[:-7] + return_string += f" Remaining time: {remaining_time}" + else: + return_string += " Estimating remaining time..." 
+ return return_string + return self.status def fetch(self) -> None: api = APIRegistry.api_for( @@ -82,6 +109,14 @@ def subjobs(self): ) return api.services.job.get_subjobs(self.id) + @property + def owner(self): + api = APIRegistry.api_for( + node_uid=self.node_uid, + user_verify_key=self.syft_client_verify_key, + ) + return api.services.user.get_current_user(self.id) + def logs(self, _print=True): api = APIRegistry.api_for( node_uid=self.node_uid, @@ -107,15 +142,17 @@ def _coll_repr_(self) -> Dict[str, Any]: logs = logs if self.result is None: - result = "" + pass else: - result = str(self.result.syft_action_data) + str(self.result.syft_action_data) return { - "status": self.status, - "logs": logs, - "result": result, - "has_parent": self.has_parent, + "progress": self.progress, + "creation date": self.creation_time[:-7], + # "logs": logs, + # "result": result, + "owner email": self.owner.email, + "parent_id": str(self.parent_job_id) if self.parent_job_id else "-", "subjobs": len(subjobs), } @@ -185,12 +222,10 @@ def set_result( item: Job, add_permissions: Optional[List[ActionObjectPermission]] = None, ) -> Result[Optional[Job], str]: - if item.resolved: - valid = self.check_type(item, self.object_type) - if valid.is_err(): - return SyftError(message=valid.err()) - return super().update(credentials, item, add_permissions) - return None + valid = self.check_type(item, self.object_type) + if valid.is_err(): + return SyftError(message=valid.err()) + return super().update(credentials, item, add_permissions) def set_placeholder( self, diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py index 0ca405ed381..6f5842e95d8 100644 --- a/packages/syft/src/syft/service/policy/policy.py +++ b/packages/syft/src/syft/service/policy/policy.py @@ -146,7 +146,6 @@ def partition_by_node(kwargs: Dict[str, Any]) -> Dict[str, UID]: uid = v.id if isinstance(v, Asset): uid = v.action_id - if not isinstance(uid, UID): raise Exception(f"Input {k} must have a UID not {type(v)}") diff --git a/packages/syft/src/syft/service/queue/zmq_queue.py b/packages/syft/src/syft/service/queue/zmq_queue.py index a952a69ed33..748960ec3d4 100644 --- a/packages/syft/src/syft/service/queue/zmq_queue.py +++ b/packages/syft/src/syft/service/queue/zmq_queue.py @@ -219,7 +219,6 @@ def post_init(self): self.thread = None def _run(self): - print("ABCDEF", flush=True) liveness = HEARTBEAT_LIVENESS interval = INTERVAL_INIT heartbeat_at = time.time() + HEARTBEAT_INTERVAL diff --git a/packages/syft/src/syft/store/blob_storage/__init__.py b/packages/syft/src/syft/store/blob_storage/__init__.py index 4dee68d51c1..712014a0ab1 100644 --- a/packages/syft/src/syft/store/blob_storage/__init__.py +++ b/packages/syft/src/syft/store/blob_storage/__init__.py @@ -115,7 +115,7 @@ def read(self) -> Union[SyftObject, SyftError]: else: return self._read_data() - def _read_data(self, stream=False): + def _read_data(self, stream=False, chunk_size=512): # relative from ...client.api import APIRegistry @@ -136,7 +136,7 @@ def _read_data(self, stream=False): response.raise_for_status() if self.type_ is BlobFileType: if stream: - return response.iter_lines() + return response.iter_lines(chunk_size=chunk_size) else: return response.content return deserialize(response.content, from_bytes=True) diff --git a/packages/syft/src/syft/store/blob_storage/seaweedfs.py b/packages/syft/src/syft/store/blob_storage/seaweedfs.py index 71dd5ad6105..fb6d3106a99 100644 --- 
a/packages/syft/src/syft/store/blob_storage/seaweedfs.py +++ b/packages/syft/src/syft/store/blob_storage/seaweedfs.py @@ -69,10 +69,12 @@ def write(self, data: BytesIO) -> Union[SyftSuccess, SyftError]: etags = [] try: + no_lines = 0 for part_no, (byte_chunk, url) in enumerate( zip(_byte_chunks(data, DEFAULT_CHUNK_SIZE), self.urls), start=1, ): + no_lines += byte_chunk.count(b"\n") if api is not None: blob_url = api.connection.to_blob_route( url.url_path, host=url.host_or_ip @@ -94,8 +96,7 @@ def write(self, data: BytesIO) -> Union[SyftSuccess, SyftError]: syft_client_verify_key=self.syft_client_verify_key, ) return mark_write_complete_method( - etags=etags, - uid=self.blob_storage_entry_id, + etags=etags, uid=self.blob_storage_entry_id, no_lines=no_lines ) diff --git a/packages/syft/src/syft/types/blob_storage.py b/packages/syft/src/syft/types/blob_storage.py index 6a7578893e4..f90d0ac8559 100644 --- a/packages/syft/src/syft/types/blob_storage.py +++ b/packages/syft/src/syft/types/blob_storage.py @@ -1,7 +1,9 @@ # stdlib import mimetypes from pathlib import Path +from queue import Queue import sys +import threading from typing import Any from typing import ClassVar from typing import List @@ -37,16 +39,38 @@ class BlobFile(SyftObject): file_name: str syft_blob_storage_entry_id: Optional[UID] = None - def read(self, stream=False): + def read(self, stream=False, chunk_size=512, force=False): # get blob retrieval object from api + syft_blob_storage_entry_id read_method = from_api_or_context( "blob_storage.read", self.syft_node_location, self.syft_client_verify_key ) blob_retrieval_object = read_method(self.syft_blob_storage_entry_id) - return blob_retrieval_object._read_data(stream=stream) - - def iter_lines(self): - return self.read(stream=True) + return blob_retrieval_object._read_data(stream=stream, chunk_size=chunk_size) + + def _iter_lines(self, chunk_size=512): + """Synchronous version of the async iter_lines""" + return self.read(stream=True, chunk_size=chunk_size) + + def read_queue(self, queue, chunk_size): + for line in self._iter_lines(chunk_size=chunk_size): + queue.put(line) + # Put anything not a string at the end + queue.put(0) + + def iter_lines(self, chunk_size=512): + item_queue: Queue = Queue() + threading.Thread( + target=self.read_queue, + args=( + item_queue, + chunk_size, + ), + daemon=True, + ).start() + item = item_queue.get() + while item != 0: + yield item + item = item_queue.get() class BlobFileType(type): @@ -97,6 +121,7 @@ class BlobStorageEntry(SyftObject): type_: Optional[Type] mimetype: str = "bytes" file_size: int + no_lines: Optional[int] = 0 uploaded_by: SyftVerifyKey created_at: DateTime = DateTime.now() @@ -109,6 +134,7 @@ class BlobStorageMetadata(SyftObject): type_: Optional[Type[SyftObject]] mimetype: str = "bytes" file_size: int + no_lines: Optional[int] = 0 @serializable()
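The closing hunk replaces the direct `iter_lines` call with a queue-backed generator: a daemon thread drains the HTTP stream into a `Queue`, and the consumer yields items until it sees the end-of-stream marker (the diff uses the integer `0`, relying on every real line being `bytes`). Here is a self-contained sketch of the same pattern; it uses a dedicated `object()` sentinel instead, since that can never compare equal to streamed data:

```python
import threading
from queue import Queue
from typing import Iterable, Iterator

_SENTINEL = object()  # unlike 0, cannot collide with any streamed line

def iter_in_background(source: Iterable[bytes]) -> Iterator[bytes]:
    """Drain `source` on a daemon thread; yield items as soon as they arrive."""
    q: Queue = Queue()

    def reader() -> None:
        for item in source:
            q.put(item)
        q.put(_SENTINEL)  # signal end of stream

    threading.Thread(target=reader, daemon=True).start()
    while (item := q.get()) is not _SENTINEL:
        yield item

# usage: consumption can start before the producer finishes
for line in iter_in_background(iter([b"a", b"b", b"c"])):
    print(line)
```

This keeps slow network reads off the caller's critical path at the cost of unbounded buffering; constructing the queue as `Queue(maxsize=...)` would add backpressure.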