diff --git a/notebooks/scenarios/bigquery/001-scale-delete-worker-pools.ipynb b/notebooks/scenarios/bigquery/001-scale-delete-worker-pools.ipynb
index e46587d23e5..884b77501ab 100644
--- a/notebooks/scenarios/bigquery/001-scale-delete-worker-pools.ipynb
+++ b/notebooks/scenarios/bigquery/001-scale-delete-worker-pools.ipynb
@@ -10,7 +10,8 @@
     "# import os\n",
     "# os.environ[\"ORCHESTRA_DEPLOYMENT_TYPE\"] = \"remote\"\n",
     "# os.environ[\"DEV_MODE\"] = \"True\"\n",
-    "# os.environ[\"TEST_EXTERNAL_REGISTRY\"] = \"k3d-registry.localhost:5800\""
+    "# os.environ[\"TEST_EXTERNAL_REGISTRY\"] = \"k3d-registry.localhost:5800\"\n",
+    "# stdlib"
    ]
   },
   {
@@ -181,7 +182,7 @@
    "id": "14",
    "metadata": {},
    "source": [
-    "##### Scale down"
+    "##### Give workers some long-running jobs\n"
    ]
   },
   {
@@ -190,6 +191,102 @@
    "id": "15",
    "metadata": {},
    "outputs": [],
+   "source": [
+    "@sy.syft_function_single_use(worker_pool_name=default_worker_pool.name)\n",
+    "def wait_1000_seconds_1():\n",
+    "    # stdlib\n",
+    "    import time\n",
+    "\n",
+    "    time.sleep(1000)\n",
+    "\n",
+    "\n",
+    "@sy.syft_function_single_use(worker_pool_name=default_worker_pool.name)\n",
+    "def wait_1000_seconds_2():\n",
+    "    # stdlib\n",
+    "    import time\n",
+    "\n",
+    "    time.sleep(1000)\n",
+    "\n",
+    "\n",
+    "@sy.syft_function_single_use(worker_pool_name=default_worker_pool.name)\n",
+    "def wait_1000_seconds_3():\n",
+    "    # stdlib\n",
+    "    import time\n",
+    "\n",
+    "    time.sleep(1000)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "20",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "jobs = []\n",
+    "high_client.code.request_code_execution(wait_1000_seconds_1)\n",
+    "high_client.code.request_code_execution(wait_1000_seconds_2)\n",
+    "high_client.code.request_code_execution(wait_1000_seconds_3)\n",
+    "\n",
+    "assert len(list(high_client.requests)) == 3\n",
+    "for request in high_client.requests:\n",
+    "    request.approve()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "21",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "jobs = []\n",
+    "jobs.append(high_client.code.wait_1000_seconds_1(blocking=False))\n",
+    "jobs.append(high_client.code.wait_1000_seconds_2(blocking=False))\n",
+    "jobs.append(high_client.code.wait_1000_seconds_3(blocking=False))\n",
+    "\n",
+    "\n",
+    "assert len(list(high_client.jobs)) == 3"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "22",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# check that at least three workers have a job (so that scaling down to 2 interrupts at least one job)\n",
+    "# try 3 times with a 20 second sleep in case it takes time for the workers to accept the jobs\n",
+    "for _ in range(3):\n",
+    "    worker_to_job_map = {}\n",
+    "    syft_workers_ids = set()\n",
+    "    for job in high_client.jobs:\n",
+    "        if job.status == \"processing\":\n",
+    "            syft_workers_ids.add(job.worker.id)\n",
+    "            worker_to_job_map[job.worker.id] = job.id\n",
+    "    print(worker_to_job_map)\n",
+    "    if len(syft_workers_ids) < 3:\n",
+    "        time.sleep(20)\n",
+    "    else:\n",
+    "        break\n",
+    "assert len(syft_workers_ids) >= 3"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "23",
+   "metadata": {},
+   "source": [
+    "##### Scale down"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "24",
+   "metadata": {},
+   "outputs": [],
    "source": [
     "# Scale down workers, this gracefully shutdowns the consumers\n",
     "if environment == \"remote\":\n",
@@ -202,26 +299,41 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "16",
+   "id": "25",
    "metadata": {},
    "outputs": [],
    "source": [
     "if environment == \"remote\":\n",
\"remote\":\n", "\n", " def has_worker_scaled_down():\n", - " return (\n", + " worker_count_condition = (\n", " high_client.api.worker_pool[default_worker_pool.name].max_count\n", " == num_workers\n", " )\n", + " current_worker_ids = {\n", + " worker.id\n", + " for worker in high_client.api.services.worker_pool[\n", + " default_worker_pool.name\n", + " ].workers\n", + " }\n", + " job_status_condition = [\n", + " job.status == \"interrupted\"\n", + " for job in high_client.jobs\n", + " if job.job_worker_id is not None\n", + " and job.job_worker_id not in current_worker_ids\n", + " ]\n", + "\n", + " jobs_on_old_workers_are_interrupted = all(job_status_condition)\n", + " return worker_count_condition and jobs_on_old_workers_are_interrupted\n", "\n", - " worker_scale_timeout = Timeout(timeout_duration=20)\n", + " worker_scale_timeout = Timeout(timeout_duration=60)\n", " worker_scale_timeout.run_with_timeout(has_worker_scaled_down)" ] }, { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "26", "metadata": {}, "outputs": [], "source": [ @@ -234,7 +346,7 @@ }, { "cell_type": "markdown", - "id": "18", + "id": "27", "metadata": {}, "source": [ "#### Delete Worker Pool" @@ -243,7 +355,7 @@ { "cell_type": "code", "execution_count": null, - "id": "19", + "id": "28", "metadata": {}, "outputs": [], "source": [ @@ -256,7 +368,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20", + "id": "29", "metadata": {}, "outputs": [], "source": [ @@ -264,9 +376,21 @@ " _ = high_client.api.services.worker_pool[default_worker_pool.name]" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "30", + "metadata": {}, + "outputs": [], + "source": [ + "# check that all jobs are interrupted\n", + "# should be the case since the entire pool was deleted and all jobs were previously assigned\n", + "assert all(job.status == \"interrupted\" for job in high_client.jobs)" + ] + }, { "cell_type": "markdown", - "id": "21", + "id": "31", "metadata": {}, "source": [ "#### Re-launch the default worker pool" @@ -275,7 +399,7 @@ { "cell_type": "code", "execution_count": null, - "id": "22", + "id": "32", "metadata": {}, "outputs": [], "source": [ @@ -285,7 +409,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "33", "metadata": {}, "outputs": [], "source": [ @@ -299,7 +423,7 @@ { "cell_type": "code", "execution_count": null, - "id": "24", + "id": "34", "metadata": {}, "outputs": [], "source": [ @@ -313,7 +437,7 @@ { "cell_type": "code", "execution_count": null, - "id": "25", + "id": "35", "metadata": {}, "outputs": [], "source": [ @@ -323,7 +447,7 @@ { "cell_type": "code", "execution_count": null, - "id": "26", + "id": "36", "metadata": {}, "outputs": [], "source": [ @@ -333,13 +457,18 @@ { "cell_type": "code", "execution_count": null, - "id": "27", + "id": "37", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { + "kernelspec": { + "display_name": "syft", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -350,7 +479,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.5" + "version": "3.12.4" } }, "nbformat": 4, diff --git a/packages/syft/src/syft/custom_worker/runner_k8s.py b/packages/syft/src/syft/custom_worker/runner_k8s.py index e320bae8e94..f57f6071a88 100644 --- a/packages/syft/src/syft/custom_worker/runner_k8s.py +++ b/packages/syft/src/syft/custom_worker/runner_k8s.py @@ -14,7 +14,7 @@ JSONPATH_AVAILABLE_REPLICAS = 
"{.status.availableReplicas}" CREATE_POOL_TIMEOUT_SEC = 380 -SCALE_POOL_TIMEOUT_SEC = 60 +SCALE_POOL_TIMEOUT_SEC = 120 class KubernetesRunner: diff --git a/packages/syft/src/syft/service/job/job_service.py b/packages/syft/src/syft/service/job/job_service.py index fbbdfd7d856..9d8334280fd 100644 --- a/packages/syft/src/syft/service/job/job_service.py +++ b/packages/syft/src/syft/service/job/job_service.py @@ -162,10 +162,15 @@ def _kill(self, context: AuthedServiceContext, job: Job) -> SyftSuccess: results.append(res) # wait for job and subjobs to be killed by MonitorThread - wait_until(lambda: job.fetched_status == JobStatus.INTERRUPTED) + + wait_until( + lambda: self.get(context, uid=job.id).status == JobStatus.INTERRUPTED + ) + subjob_uids = [subjob.id for subjob in self.get_subjobs(context, uid=job.id)] wait_until( lambda: all( - subjob.fetched_status == JobStatus.INTERRUPTED for subjob in job.subjobs + self.get(context, uid=subjob_id).status == JobStatus.INTERRUPTED + for subjob_id in subjob_uids ) ) diff --git a/packages/syft/src/syft/service/queue/queue.py b/packages/syft/src/syft/service/queue/queue.py index ae170cad95b..fb7d9b068ca 100644 --- a/packages/syft/src/syft/service/queue/queue.py +++ b/packages/syft/src/syft/service/queue/queue.py @@ -75,7 +75,10 @@ def monitor(self) -> None: ).unwrap() if job and job.status == JobStatus.TERMINATING: self.terminate(job) - for subjob in job.subjobs: + subjobs = self.worker.job_stash.get_by_parent_id( + self.credentials, job.id + ).unwrap() + for subjob in subjobs: self.terminate(subjob) self.queue_item.status = Status.INTERRUPTED diff --git a/packages/syft/src/syft/service/queue/zmq_consumer.py b/packages/syft/src/syft/service/queue/zmq_consumer.py index f6993d6b032..2ec06916826 100644 --- a/packages/syft/src/syft/service/queue/zmq_consumer.py +++ b/packages/syft/src/syft/service/queue/zmq_consumer.py @@ -291,6 +291,7 @@ def _set_worker_job(self, job_id: UID | None) -> None: credentials=self.worker_stash.partition.root_verify_key, worker_uid=self.syft_worker_id, consumer_state=consumer_state, + job_id=job_id, ) if res.is_err(): logger.error( diff --git a/packages/syft/src/syft/service/queue/zmq_producer.py b/packages/syft/src/syft/service/queue/zmq_producer.py index 5cb5056f8a2..38112aa2654 100644 --- a/packages/syft/src/syft/service/queue/zmq_producer.py +++ b/packages/syft/src/syft/service/queue/zmq_producer.py @@ -264,10 +264,11 @@ def purge_workers(self) -> None: self.delete_worker(worker, syft_worker.to_be_deleted) # relative - - self.auth_context.server.services.worker._delete( - self.auth_context, syft_worker - ) + # if worker has expired, then delete it. 
+                # should also delete if monitor thread is not alive
+                """ self.auth_context.server.services.worker.delete(
+                    self.auth_context, syft_worker, force=True
+                ) """
 
     def update_consumer_state_for_worker(
         self, syft_worker_id: UID, consumer_state: ConsumerState
diff --git a/packages/syft/src/syft/service/worker/worker_pool_service.py b/packages/syft/src/syft/service/worker/worker_pool_service.py
index 4ceced2bf26..8a344506183 100644
--- a/packages/syft/src/syft/service/worker/worker_pool_service.py
+++ b/packages/syft/src/syft/service/worker/worker_pool_service.py
@@ -449,28 +449,33 @@ def scale(
             )
         else:
             # scale down at kubernetes control plane
-            runner = KubernetesRunner()
-            scale_kubernetes_pool(
-                runner,
-                pool_name=worker_pool.name,
-                replicas=number,
-            ).unwrap()
 
             # scale down removes the last "n" workers
             # workers to delete = len(workers) - number
-            workers_to_delete = worker_pool.worker_list[
-                -(current_worker_count - number) :
+            workers = [
+                worker.resolve_with_context(context=context).unwrap()
+                for worker in worker_pool.worker_list
             ]
-            worker_stash = context.server.services.worker.stash
-            # delete linkedobj workers
+            # get the last "n" workers to remove
+            runner = KubernetesRunner()
+            workers_to_delete = workers[
+                -(current_worker_count - number) :]
+            worker_service = context.server.services.worker
+
+            # mark workers to be deleted and wait for the producer thread to perform the deletions
             for worker in workers_to_delete:
-                worker_stash.delete_by_uid(
-                    credentials=context.credentials,
-                    uid=worker.object_uid,
-                ).unwrap()
+                worker.to_delete = True
 
-            client_warning += "Scaling down workers doesn't kill the associated jobs. Please delete them manually."
+                worker_stash = context.server.services.worker.stash
+                worker_stash.update(context.credentials, obj=worker)
+                # worker_service.delete(context=context, uid=worker.id, force=True)
+
+            scale_kubernetes_pool(
+                runner,
+                pool_name=worker_pool.name,
+                replicas=number,
+            ).unwrap()
 
             # update worker_pool
             worker_pool.max_count = number
diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py
index 625a88a46b4..0e717e8cb31 100644
--- a/packages/syft/src/syft/service/worker/worker_service.py
+++ b/packages/syft/src/syft/service/worker/worker_service.py
@@ -35,6 +35,7 @@
 from .worker_pool import _get_worker_container
 from .worker_pool import _get_worker_container_status
 from .worker_stash import WorkerStash
+from ..job.job_stash import JobStatus
 
 
 @serializable(canonical_name="WorkerService", version=1)
@@ -144,11 +145,26 @@ def _delete(
         worker_pool = worker_pool_stash.get_by_name(
             credentials=context.credentials, pool_name=worker.worker_pool_name
         ).unwrap()
+        # remove the worker from the pool
+        try:
+            worker_linked_object = next(
+                obj for obj in worker_pool.worker_list if obj.object_uid == uid
+            )
+            worker_pool.worker_list.remove(worker_linked_object)
+        except StopIteration:
+            pass
+
+        # Delete worker from worker stash
+        self.stash.delete_by_uid(credentials=context.credentials, uid=uid).unwrap()
+
+        # Update worker pool
+
+        worker_pool_stash.update(context.credentials, obj=worker_pool).unwrap()
 
         if IN_KUBERNETES:
             # Kubernetes will only restart the worker NOT REMOVE IT
             runner = KubernetesRunner()
-            runner.delete_pod(pod_name=worker.name)
+            # runner.delete_pod(pod_name=worker.name)
             return SyftSuccess(
                 # pod deletion is not supported in Kubernetes, removing and recreating the pod.
                 message=(
@@ -165,20 +181,7 @@
         # kill the in memory worker thread
         context.server.remove_consumer_with_id(syft_worker_id=worker.id)
 
-        # remove the worker from the pool
-        try:
-            worker_linked_object = next(
-                obj for obj in worker_pool.worker_list if obj.object_uid == uid
-            )
-            worker_pool.worker_list.remove(worker_linked_object)
-        except StopIteration:
-            pass
-
-        # Delete worker from worker stash
-        self.stash.delete_by_uid(credentials=context.credentials, uid=uid).unwrap()
-
-        # Update worker pool
-        worker_pool_stash.update(context.credentials, obj=worker_pool).unwrap()
+
 
         return SyftSuccess(
             message=f"Worker with id: {uid} deleted successfully from pool: {worker_pool.name}"
diff --git a/packages/syft/src/syft/service/worker/worker_stash.py b/packages/syft/src/syft/service/worker/worker_stash.py
index b2b059ffec5..302395d0028 100644
--- a/packages/syft/src/syft/service/worker/worker_stash.py
+++ b/packages/syft/src/syft/service/worker/worker_stash.py
@@ -67,8 +67,17 @@ def get_worker_by_name(
 
     @as_result(StashException, NotFoundException)
     def update_consumer_state(
-        self, credentials: SyftVerifyKey, worker_uid: UID, consumer_state: ConsumerState
+        self,
+        credentials: SyftVerifyKey,
+        worker_uid: UID,
+        consumer_state: ConsumerState,
+        job_id: UID | None = None,
     ) -> SyftWorker:
         worker = self.get_by_uid(credentials=credentials, uid=worker_uid).unwrap()
         worker.consumer_state = consumer_state
+        if job_id is not None:
+            worker.job_id = job_id
+
+        if worker.consumer_state == ConsumerState.IDLE:
+            worker.job_id = None
         return self.update(credentials=credentials, obj=worker).unwrap()