Skip to content

Commit

Permalink
Update plugins to inherit from WorkerPlugin
Browse files Browse the repository at this point in the history
  • Loading branch information
jrbourbeau committed Sep 5, 2023
1 parent 171fd2c commit c8e0982
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions dask_cuda/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import distributed # noqa: required for dask.config.get("distributed.comm.ucx")
from dask.config import canonical_name
from dask.utils import format_bytes, parse_bytes
from distributed import Worker, wait
from distributed import Worker, WorkerPlugin, wait
from distributed.comm import parse_address

try:
Expand All @@ -32,15 +32,15 @@ def nvtx_annotate(message=None, color="blue", domain=None):
yield


class CPUAffinity:
class CPUAffinity(WorkerPlugin):
def __init__(self, cores):
self.cores = cores

def setup(self, worker=None):
os.sched_setaffinity(0, self.cores)


class RMMSetup:
class RMMSetup(WorkerPlugin):
def __init__(
self,
initial_pool_size,
Expand Down Expand Up @@ -135,7 +135,7 @@ def setup(self, worker=None):
rmm.mr.set_current_device_resource(rmm.mr.TrackingResourceAdaptor(mr))


class PreImport:
class PreImport(WorkerPlugin):
def __init__(self, libraries):
if libraries is None:
libraries = []
Expand Down

0 comments on commit c8e0982

Please sign in to comment.