
[WIP] Scaling down k8s worker pool interrupts associated jobs #9278

Closed
wants to merge 11 commits into from
163 changes: 146 additions & 17 deletions notebooks/scenarios/bigquery/001-scale-delete-worker-pools.ipynb
@@ -10,7 +10,8 @@
"# import os\n",
"# os.environ[\"ORCHESTRA_DEPLOYMENT_TYPE\"] = \"remote\"\n",
"# os.environ[\"DEV_MODE\"] = \"True\"\n",
"# os.environ[\"TEST_EXTERNAL_REGISTRY\"] = \"k3d-registry.localhost:5800\""
"# os.environ[\"TEST_EXTERNAL_REGISTRY\"] = \"k3d-registry.localhost:5800\"\n",
"# stdlib"
]
},
{
@@ -181,7 +182,7 @@
"id": "14",
"metadata": {},
"source": [
"##### Scale down"
"##### Give workers some long-running jobs\n"
]
},
{
@@ -190,6 +191,102 @@
"id": "15",
"metadata": {},
"outputs": [],
"source": [
"@sy.syft_function_single_use(worker_pool_name=default_worker_pool.name)\n",
"def wait_1000_seconds_1():\n",
" # stdlib\n",
" import time\n",
"\n",
" time.sleep(1000)\n",
"\n",
"\n",
"@sy.syft_function_single_use(worker_pool_name=default_worker_pool.name)\n",
"def wait_1000_seconds_2():\n",
" # stdlib\n",
" import time\n",
"\n",
" time.sleep(1000)\n",
"\n",
"\n",
"@sy.syft_function_single_use(worker_pool_name=default_worker_pool.name)\n",
"def wait_1000_seconds_3():\n",
" # stdlib\n",
" import time\n",
"\n",
" time.sleep(1000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "20",
"metadata": {},
"outputs": [],
"source": [
"jobs = []\n",
"high_client.code.request_code_execution(wait_1000_seconds_1)\n",
"high_client.code.request_code_execution(wait_1000_seconds_2)\n",
"high_client.code.request_code_execution(wait_1000_seconds_3)\n",
"\n",
"assert len(list(high_client.requests)) == 3\n",
"for request in high_client.requests:\n",
" request.approve()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "21",
"metadata": {},
"outputs": [],
"source": [
"jobs = []\n",
"jobs.append(high_client.code.wait_1000_seconds_1(blocking=False))\n",
"jobs.append(high_client.code.wait_1000_seconds_2(blocking=False))\n",
"jobs.append(high_client.code.wait_1000_seconds_3(blocking=False))\n",
"\n",
"\n",
"assert len(list(high_client.jobs)) == 3"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "22",
"metadata": {},
"outputs": [],
"source": [
"# check that at least three workers have a job (since scaling down to 2)\n",
"# try 3 times with a 5 second sleep in case it takes time for the workers to accept the jobs\n",
"for _ in range(3):\n",
" worker_to_job_map = {}\n",
" syft_workers_ids = set()\n",
" for job in high_client.jobs:\n",
" if job.status == \"processing\":\n",
" syft_workers_ids.add(job.worker.id)\n",
" worker_to_job_map[job.worker.id] = job.id\n",
" print(worker_to_job_map)\n",
" if len(syft_workers_ids) < 3:\n",
" time.sleep(20)\n",
" else:\n",
" break\n",
"assert len(syft_workers_ids) >= 3"
]
},
{
"cell_type": "markdown",
"id": "23",
"metadata": {},
"source": [
"##### Scale down"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "24",
"metadata": {},
"outputs": [],
"source": [
"# Scale down workers, this gracefully shutdowns the consumers\n",
"if environment == \"remote\":\n",
@@ -202,26 +299,41 @@
{
"cell_type": "code",
"execution_count": null,
"id": "16",
"id": "25",
"metadata": {},
"outputs": [],
"source": [
"if environment == \"remote\":\n",
"\n",
" def has_worker_scaled_down():\n",
" return (\n",
" worker_count_condition = (\n",
" high_client.api.worker_pool[default_worker_pool.name].max_count\n",
" == num_workers\n",
" )\n",
" current_worker_ids = {\n",
" worker.id\n",
" for worker in high_client.api.services.worker_pool[\n",
" default_worker_pool.name\n",
" ].workers\n",
" }\n",
" job_status_condition = [\n",
" job.status == \"interrupted\"\n",
" for job in high_client.jobs\n",
" if job.job_worker_id is not None\n",
" and job.job_worker_id not in current_worker_ids\n",
" ]\n",
"\n",
" jobs_on_old_workers_are_interrupted = all(job_status_condition)\n",
" return worker_count_condition and jobs_on_old_workers_are_interrupted\n",
"\n",
" worker_scale_timeout = Timeout(timeout_duration=20)\n",
" worker_scale_timeout = Timeout(timeout_duration=60)\n",
" worker_scale_timeout.run_with_timeout(has_worker_scaled_down)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "17",
"id": "26",
"metadata": {},
"outputs": [],
"source": [
@@ -234,7 +346,7 @@
},
{
"cell_type": "markdown",
"id": "18",
"id": "27",
"metadata": {},
"source": [
"#### Delete Worker Pool"
@@ -243,7 +355,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "19",
"id": "28",
"metadata": {},
"outputs": [],
"source": [
@@ -256,17 +368,29 @@
{
"cell_type": "code",
"execution_count": null,
"id": "20",
"id": "29",
"metadata": {},
"outputs": [],
"source": [
"with sy.raises(KeyError):\n",
" _ = high_client.api.services.worker_pool[default_worker_pool.name]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "30",
"metadata": {},
"outputs": [],
"source": [
"# check that all jobs are interrupted\n",
"# should be the case since the entire pool was deleted and all jobs were previously assigned\n",
"assert all(job.status == \"interrupted\" for job in high_client.jobs)"
]
},
{
"cell_type": "markdown",
"id": "21",
"id": "31",
"metadata": {},
"source": [
"#### Re-launch the default worker pool"
@@ -275,7 +399,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "22",
"id": "32",
"metadata": {},
"outputs": [],
"source": [
@@ -285,7 +409,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "23",
"id": "33",
"metadata": {},
"outputs": [],
"source": [
@@ -299,7 +423,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "24",
"id": "34",
"metadata": {},
"outputs": [],
"source": [
@@ -313,7 +437,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "25",
"id": "35",
"metadata": {},
"outputs": [],
"source": [
@@ -323,7 +447,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "26",
"id": "36",
"metadata": {},
"outputs": [],
"source": [
@@ -333,13 +457,18 @@
{
"cell_type": "code",
"execution_count": null,
"id": "27",
"id": "37",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "syft",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
@@ -350,7 +479,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
"version": "3.12.4"
}
},
"nbformat": 4,
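The notebook now submits three long-running jobs, confirms that at least three workers are busy, scales the pool down, and asserts that jobs on removed workers end up interrupted. It polls with `Timeout(timeout_duration=60).run_with_timeout(has_worker_scaled_down)`; below is a minimal sketch of such a helper, assuming it simply re-evaluates the predicate until it passes or the duration elapses. The real Syft test utility may differ in signature and error handling.

```python
# Minimal sketch of a Timeout helper matching the notebook's usage
# (hypothetical implementation; the actual Syft utility may differ).
import time
from typing import Callable


class Timeout:
    def __init__(self, timeout_duration: float, interval: float = 1.0) -> None:
        self.timeout_duration = timeout_duration
        self.interval = interval

    def run_with_timeout(self, predicate: Callable[[], bool]) -> None:
        # Re-evaluate the predicate until it passes or the duration elapses.
        deadline = time.monotonic() + self.timeout_duration
        while time.monotonic() < deadline:
            if predicate():
                return
            time.sleep(self.interval)
        raise TimeoutError(
            f"Condition not met within {self.timeout_duration} seconds"
        )
```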
2 changes: 1 addition & 1 deletion packages/syft/src/syft/custom_worker/runner_k8s.py
@@ -14,7 +14,7 @@

JSONPATH_AVAILABLE_REPLICAS = "{.status.availableReplicas}"
CREATE_POOL_TIMEOUT_SEC = 380
SCALE_POOL_TIMEOUT_SEC = 60
SCALE_POOL_TIMEOUT_SEC = 120


class KubernetesRunner:
9 changes: 7 additions & 2 deletions packages/syft/src/syft/service/job/job_service.py
@@ -162,10 +162,15 @@ def _kill(self, context: AuthedServiceContext, job: Job) -> SyftSuccess:
results.append(res)

# wait for job and subjobs to be killed by MonitorThread
wait_until(lambda: job.fetched_status == JobStatus.INTERRUPTED)

wait_until(
lambda: self.get(context, uid=job.id).status == JobStatus.INTERRUPTED
)
subjob_uids = [subjob.id for subjob in self.get_subjobs(context, uid=job.id)]
wait_until(
lambda: all(
subjob.fetched_status == JobStatus.INTERRUPTED for subjob in job.subjobs
self.get(context, uid=subjob_id).status == JobStatus.INTERRUPTED
for subjob_id in subjob_uids
)
)

5 changes: 4 additions & 1 deletion packages/syft/src/syft/service/queue/queue.py
@@ -75,7 +75,10 @@ def monitor(self) -> None:
).unwrap()
if job and job.status == JobStatus.TERMINATING:
self.terminate(job)
for subjob in job.subjobs:
subjobs = self.worker.job_stash.get_by_parent_id(
self.credentials, job.id
).unwrap()
for subjob in subjobs:
self.terminate(subjob)

self.queue_item.status = Status.INTERRUPTED
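This hunk switches the monitor loop from iterating the cached `job.subjobs` to re-querying the job stash with `get_by_parent_id`, so subjobs spawned after the parent object was loaded still get terminated. A minimal in-memory sketch of that lookup pattern, with illustrative types rather than the real Syft stash API:

```python
# In-memory sketch of looking up subjobs by parent id at termination time
# (illustrative types; the real Syft stash is persistent and result-wrapped).
from dataclasses import dataclass, field
from uuid import UUID, uuid4


@dataclass
class JobRecord:
    id: UUID = field(default_factory=uuid4)
    parent_id: UUID | None = None
    status: str = "processing"


class InMemoryJobStash:
    def __init__(self) -> None:
        self._jobs: dict[UUID, JobRecord] = {}

    def add(self, job: JobRecord) -> None:
        self._jobs[job.id] = job

    def get_by_parent_id(self, parent_id: UUID) -> list[JobRecord]:
        # Read the current records rather than a cached job.subjobs list, so
        # subjobs created after the parent object was fetched are not missed.
        return [j for j in self._jobs.values() if j.parent_id == parent_id]


# Termination then walks the freshly fetched subjobs:
# for subjob in stash.get_by_parent_id(job.id):
#     terminate(subjob)
```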
1 change: 1 addition & 0 deletions packages/syft/src/syft/service/queue/zmq_consumer.py
@@ -291,6 +291,7 @@ def _set_worker_job(self, job_id: UID | None) -> None:
credentials=self.worker_stash.partition.root_verify_key,
worker_uid=self.syft_worker_id,
consumer_state=consumer_state,
job_id=job_id,
)
if res.is_err():
logger.error(
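The single added argument here records the current `job_id` on the consumer state, so the server can map each worker to the job it is running. A rough sketch of the kind of bookkeeping this enables, using illustrative names rather than the real Syft types: when workers disappear during a scale-down, their in-flight jobs can be looked up and marked interrupted.

```python
# Illustrative sketch only; not the real Syft consumer-state types.
from uuid import UUID


class ConsumerJobTracker:
    """Tracks which job each worker is currently running."""

    def __init__(self) -> None:
        self._current_job: dict[UUID, UUID | None] = {}

    def set_worker_job(self, worker_id: UUID, job_id: UUID | None) -> None:
        # job_id is None when the worker goes idle.
        self._current_job[worker_id] = job_id

    def jobs_to_interrupt(self, removed_worker_ids: set[UUID]) -> list[UUID]:
        # Jobs whose workers were removed (e.g. by a scale-down) should be
        # marked interrupted rather than left hanging.
        return [
            job_id
            for worker_id, job_id in self._current_job.items()
            if worker_id in removed_worker_ids and job_id is not None
        ]
```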
9 changes: 5 additions & 4 deletions packages/syft/src/syft/service/queue/zmq_producer.py
@@ -264,10 +264,11 @@ def purge_workers(self) -> None:
self.delete_worker(worker, syft_worker.to_be_deleted)

# relative

self.auth_context.server.services.worker._delete(
self.auth_context, syft_worker
)
# If the worker has expired, delete it; otherwise cleanup is handled by the monitor thread.
# We should also delete it here if the monitor thread is not alive.
""" self.auth_context.server.services.worker.delete(
self.auth_context, syft_worker, force=True
) """

def update_consumer_state_for_worker(
self, syft_worker_id: UID, consumer_state: ConsumerState
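This hunk removes the eager `_delete` of the expired worker and leaves behind a commented-out `delete(..., force=True)` with a rationale: cleanup is normally owned by the monitor thread, and a delete here would only be needed when that thread is not alive. A hypothetical sketch of that decision, with illustrative names rather than the real producer attributes:

```python
# Hypothetical sketch of the purge decision described in the comments above.
def should_delete_worker_here(expired: bool, monitor_thread_alive: bool) -> bool:
    if not expired:
        return False
    # A live monitor thread owns cleanup of the worker and its in-flight job;
    # deleting here as well would race with it.
    return not monitor_thread_alive
```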