diff --git a/dask_cuda/local_cuda_cluster.py b/dask_cuda/local_cuda_cluster.py index e84dce0c7..57fe85439 100644 --- a/dask_cuda/local_cuda_cluster.py +++ b/dask_cuda/local_cuda_cluster.py @@ -2,7 +2,6 @@ import logging import os import warnings -from functools import partial import dask from distributed import LocalCluster, Nanny, Worker @@ -332,10 +331,14 @@ def __init__( enable_rdmacm=enable_rdmacm, ) - worker_class = partial( - LoggedNanny if log_spilling is True else Nanny, - worker_class=worker_class, - ) + if worker_class is not None and log_spilling is True: + raise ValueError( + "Cannot enable `log_spilling` when `worker_class` is specified. If " + "logging is needed, ensure `worker_class` is a subclass of " + "`distributed.local_cuda_cluster.LoggedNanny` or a subclass of " + "`distributed.local_cuda_cluster.LoggedWorker`, and specify " + "`log_spilling=False`." + ) self.pre_import = pre_import