diff --git a/dask_cuda/local_cuda_cluster.py b/dask_cuda/local_cuda_cluster.py index 82f59346e..401510ec8 100644 --- a/dask_cuda/local_cuda_cluster.py +++ b/dask_cuda/local_cuda_cluster.py @@ -340,7 +340,7 @@ def __init__( "`distributed.local_cuda_cluster.LoggedWorker`, and specify " "`log_spilling=False`." ) - if not isinstance(worker_class, Nanny): + if not issubclass(worker_class, Nanny): worker_class = partial(Nanny, worker_class=worker_class) self.pre_import = pre_import