From dfe02d8ec90f2e6382741aa12bde9dd6eb516251 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 24 Oct 2023 13:26:35 -0700 Subject: [PATCH] Fix `Nanny` subclass check --- dask_cuda/local_cuda_cluster.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/dask_cuda/local_cuda_cluster.py b/dask_cuda/local_cuda_cluster.py index 82f59346e..d0ea92748 100644 --- a/dask_cuda/local_cuda_cluster.py +++ b/dask_cuda/local_cuda_cluster.py @@ -332,16 +332,17 @@ def __init__( enable_rdmacm=enable_rdmacm, ) - if worker_class is not None and log_spilling is True: - raise ValueError( - "Cannot enable `log_spilling` when `worker_class` is specified. If " - "logging is needed, ensure `worker_class` is a subclass of " - "`distributed.local_cuda_cluster.LoggedNanny` or a subclass of " - "`distributed.local_cuda_cluster.LoggedWorker`, and specify " - "`log_spilling=False`." - ) - if not isinstance(worker_class, Nanny): - worker_class = partial(Nanny, worker_class=worker_class) + if worker_class is not None: + if log_spilling is True: + raise ValueError( + "Cannot enable `log_spilling` when `worker_class` is specified. If " + "logging is needed, ensure `worker_class` is a subclass of " + "`distributed.local_cuda_cluster.LoggedNanny` or a subclass of " + "`distributed.local_cuda_cluster.LoggedWorker`, and specify " + "`log_spilling=False`." + ) + if not issubclass(worker_class, Nanny): + worker_class = partial(Nanny, worker_class=worker_class) self.pre_import = pre_import