diff --git a/dask_cuda/utils.py b/dask_cuda/utils.py index a155dc593..1e244bb31 100644 --- a/dask_cuda/utils.py +++ b/dask_cuda/utils.py @@ -18,7 +18,7 @@ import distributed # noqa: required for dask.config.get("distributed.comm.ucx") from dask.config import canonical_name from dask.utils import format_bytes, parse_bytes -from distributed import Worker, wait +from distributed import Worker, WorkerPlugin, wait from distributed.comm import parse_address try: @@ -32,7 +32,7 @@ def nvtx_annotate(message=None, color="blue", domain=None): yield -class CPUAffinity: +class CPUAffinity(WorkerPlugin): def __init__(self, cores): self.cores = cores @@ -40,7 +40,7 @@ def setup(self, worker=None): os.sched_setaffinity(0, self.cores) -class RMMSetup: +class RMMSetup(WorkerPlugin): def __init__( self, initial_pool_size, @@ -135,7 +135,7 @@ def setup(self, worker=None): rmm.mr.set_current_device_resource(rmm.mr.TrackingResourceAdaptor(mr)) -class PreImport: +class PreImport(WorkerPlugin): def __init__(self, libraries): if libraries is None: libraries = []