diff --git a/dask_cuda/cli.py b/dask_cuda/cli.py
index d454129a..ea90b96e 100644
--- a/dask_cuda/cli.py
+++ b/dask_cuda/cli.py
@@ -13,7 +13,7 @@
 from distributed.utils import import_term
 
 from .cuda_worker import CUDAWorker
-from .utils import print_cluster_config
+from .utils import CommaSeparatedChoice, print_cluster_config
 
 logger = logging.getLogger(__name__)
 
@@ -167,6 +167,7 @@ def cuda():
 @click.option(
     "--set-rmm-allocator-for-libs",
     "rmm_allocator_external_lib_list",
+    type=CommaSeparatedChoice(["cupy", "torch"]),
     default=None,
     show_default=True,
     help="""
diff --git a/dask_cuda/cuda_worker.py b/dask_cuda/cuda_worker.py
index 91128eb1..30c14450 100644
--- a/dask_cuda/cuda_worker.py
+++ b/dask_cuda/cuda_worker.py
@@ -203,9 +203,6 @@ def del_pid_file():
                 "processes set `CUDF_SPILL=on` as well. To disable this warning "
                 "set `DASK_CUDF_SPILL_WARNING=False`."
             )
-
-        if rmm_allocator_external_lib_list is not None:
-            rmm_allocator_external_lib_list = [s.strip() for s in rmm_allocator_external_lib_list.split(',')]
 
         self.nannies = [
             Nanny(
diff --git a/dask_cuda/local_cuda_cluster.py b/dask_cuda/local_cuda_cluster.py
index bce6ed64..9d19c798 100644
--- a/dask_cuda/local_cuda_cluster.py
+++ b/dask_cuda/local_cuda_cluster.py
@@ -143,10 +143,10 @@ class LocalCUDACluster(LocalCluster):
         The asynchronous allocator requires CUDA Toolkit 11.2 or newer. It is
         also incompatible with RMM pools and managed memory. Trying to enable
         both will result in an exception.
-    rmm_allocator_external_lib_list: str or list or None, default None
-        Set RMM as the allocator for external libraries. Can be a comma-separated
-        string (like "torch,cupy").
-
+    rmm_allocator_external_lib_list: list or None, default None
+        List of external libraries for which to set RMM as the allocator.
+        Supported options are: ``["torch", "cupy"]``. If None, no external
+        libraries will use RMM as their allocator.
     rmm_release_threshold: int, str or None, default None
         When ``rmm.async is True`` and the pool size grows beyond this value, unused
         memory held by the pool will be released at the next synchronization point.
@@ -289,12 +289,6 @@ def __init__(
         self.rmm_managed_memory = rmm_managed_memory
         self.rmm_async = rmm_async
         self.rmm_release_threshold = rmm_release_threshold
-        if rmm_allocator_external_lib_list is not None and isinstance(
-            rmm_allocator_external_lib_list, str
-        ):
-            rmm_allocator_external_lib_list = [
-                s.strip() for s in rmm_allocator_external_lib_list.split(",")
-            ]
         self.rmm_allocator_external_lib_list = rmm_allocator_external_lib_list
 
         if rmm_pool_size is not None or rmm_managed_memory or rmm_async:
diff --git a/dask_cuda/utils.py b/dask_cuda/utils.py
index 2f92d94b..e7d7cdbb 100644
--- a/dask_cuda/utils.py
+++ b/dask_cuda/utils.py
@@ -7,8 +7,9 @@
 from contextlib import suppress
 from functools import singledispatch
 from multiprocessing import cpu_count
-from typing import Optional, Callable, Dict
+from typing import Callable, Dict, Optional
 
+import click
 import numpy as np
 import pynvml
 import toolz
@@ -771,7 +772,7 @@ def enable_rmm_memory_for_library(lib_name: str) -> None:
     Enable RMM memory pool support for a specified third-party library.
 
     This function allows the given library to utilize RMM's memory pool if it supports
-    integration with RMM. The library name is passed as a string argument, and if the 
+    integration with RMM. The library name is passed as a string argument, and if the
     library is compatible, its memory allocator will be configured to use RMM.
 
     Parameters
@@ -794,7 +795,7 @@ def enable_rmm_memory_for_library(lib_name: str) -> None:
     }
 
     if lib_name not in setup_functions:
-        supported_libs = ', '.join(setup_functions.keys())
+        supported_libs = ", ".join(setup_functions.keys())
         raise ValueError(
             f"The library '{lib_name}' is not supported for RMM integration. "
             f"Supported libraries are: {supported_libs}."
@@ -803,6 +804,7 @@ def enable_rmm_memory_for_library(lib_name: str) -> None:
     # Call the setup function for the specified library
     setup_functions[lib_name]()
 
+
 def _setup_rmm_for_torch() -> None:
     try:
         import torch
@@ -813,6 +815,7 @@ def _setup_rmm_for_torch() -> None:
 
     torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
 
+
 def _setup_rmm_for_cupy() -> None:
     try:
         import cupy
@@ -820,4 +823,15 @@ def _setup_rmm_for_cupy() -> None:
         raise ImportError("CuPy is not installed.") from e
 
     from rmm.allocators.cupy import rmm_cupy_allocator
+
     cupy.cuda.set_allocator(rmm_cupy_allocator)
+
+
+class CommaSeparatedChoice(click.Choice):
+    def convert(self, value, param, ctx):
+        values = [v.strip() for v in value.split(",")]
+        for v in values:
+            if v not in self.choices:
+                choices_str = ", ".join(f"'{c}'" for c in self.choices)
+                self.fail(f"invalid choice(s): {v}. (choices are: {choices_str})")
+        return values
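
Usage sketch for reviewers (not part of the patch): after this change the option takes a list through the Python API, while the CLI keeps accepting a comma-separated string, now split and validated up front by `CommaSeparatedChoice` against `["cupy", "torch"]`. The scheduler address below is a placeholder.

```python
from dask_cuda import LocalCUDACluster

# Python API: pass a list directly; the ad-hoc string splitting removed
# from cuda_worker.py and local_cuda_cluster.py above is no longer needed.
cluster = LocalCUDACluster(rmm_allocator_external_lib_list=["cupy", "torch"])

# CLI: CommaSeparatedChoice converts "cupy,torch" into ["cupy", "torch"]
# before the worker starts, so an unsupported name now fails immediately
# with a click usage error instead of surfacing later inside the worker:
#
#   $ dask cuda worker <scheduler-address> --set-rmm-allocator-for-libs cupy,torch
```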