Skip to content

Commit

Permalink
Move test utils to dask_cuda.utils_test module
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Oct 23, 2023
1 parent 2e73bc4 commit 723fa7e
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 22 deletions.
23 changes: 1 addition & 22 deletions dask_cuda/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import distributed # noqa: required for dask.config.get("distributed.comm.ucx")
from dask.config import canonical_name
from dask.utils import format_bytes, parse_bytes
from distributed import Worker, wait
from distributed import wait
from distributed.comm import parse_address

try:
Expand Down Expand Up @@ -552,27 +552,6 @@ def _align(size, alignment_size):
return _align(int(device_memory_limit), alignment_size)


class MockWorker(Worker):
"""Mock Worker class preventing NVML from getting used by SystemMonitor.
By preventing the Worker from initializing NVML in the SystemMonitor, we can
mock test multiple devices in `CUDA_VISIBLE_DEVICES` behavior with single-GPU
machines.
"""

def __init__(self, *args, **kwargs):
distributed.diagnostics.nvml.device_get_count = MockWorker.device_get_count
self._device_get_count = distributed.diagnostics.nvml.device_get_count
super().__init__(*args, **kwargs)

def __del__(self):
distributed.diagnostics.nvml.device_get_count = self._device_get_count

@staticmethod
def device_get_count():
return 0


def get_gpu_uuid_from_index(device_index=0):
"""Get GPU UUID from CUDA device index.
Expand Down
23 changes: 23 additions & 0 deletions dask_cuda/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import distributed
from distributed import Worker


class MockWorker(Worker):
"""Mock Worker class preventing NVML from getting used by SystemMonitor.
By preventing the Worker from initializing NVML in the SystemMonitor, we can
mock test multiple devices in `CUDA_VISIBLE_DEVICES` behavior with single-GPU
machines.
"""

def __init__(self, *args, **kwargs):
distributed.diagnostics.nvml.device_get_count = MockWorker.device_get_count
self._device_get_count = distributed.diagnostics.nvml.device_get_count
super().__init__(*args, **kwargs)

def __del__(self):
distributed.diagnostics.nvml.device_get_count = self._device_get_count

@staticmethod
def device_get_count():
return 0

0 comments on commit 723fa7e

Please sign in to comment.