Skip to content

Commit

Permalink
Handle exceptions in Celery worker initialization, set GPU memory limit based on environment variable and add gpu limits to deployment configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
devxpy committed Aug 12, 2024
1 parent b62238b commit a5e5fb1
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 2 deletions.
13 changes: 11 additions & 2 deletions celeryconfig.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import traceback
import typing

from celery import Celery
from celery.exceptions import WorkerShutdown
from celery.signals import worker_init
from kombu import Queue

Expand Down Expand Up @@ -33,8 +35,15 @@ def setup_queues(
queue_prefix: str = os.environ.get("QUEUE_PREFIX", "gooey-gpu"),
):
def init(**kwargs):
for model_id in model_ids:
load_fn(model_id)
model_id = None
try:
for model_id in model_ids:
load_fn(model_id)
except:
# for some reason, celery seems to swallow exceptions in init
print(f"Error loading {model_id}:")
traceback.print_exc()
raise WorkerShutdown()

init_fns.append(init)

Expand Down
4 changes: 4 additions & 0 deletions chart/templates/deployment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ spec:
value: "{{ $value }}"
{{- end }}
{{- end }}
{{- range $name, $value := .limits }}
- name: "RESOURCE_LIMITS_{{ $name | upper }}"
value: "{{ $value }}"
{{- end }}
livenessProbe:
exec:
command: [ "bash", "-c", "celery inspect ping -d celery@$HOSTNAME" ]
Expand Down
11 changes: 11 additions & 0 deletions gooey_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@
or "/root/.cache/gooey-gpu/checkpoints"
)

try:
gpu_limit_gib = float(os.environ["RESOURCE_LIMITS_GPU"].removesuffix("Gi"))
except (KeyError, ValueError):
print("RESOURCE_LIMITS_GPU environment variable not set to a valid value.")
else:
total_mem_bytes = torch.cuda.mem_get_info()[1]
fraction = gpu_limit_gib * 1024**3 / total_mem_bytes
torch.cuda.set_per_process_memory_fraction(fraction)
print(f"GPU limit set to {gpu_limit_gib}Gi ({fraction:.2%})")


if SENTRY_DSN:
sentry_sdk.init(
dsn=SENTRY_DSN,
Expand Down

0 comments on commit a5e5fb1

Please sign in to comment.