Address Peter's review
Signed-off-by: Vibhu Jawa <[email protected]>
VibhuJawa committed Oct 8, 2024
1 parent d046478 commit 35647ae
Showing 4 changed files with 23 additions and 17 deletions.
3 changes: 2 additions & 1 deletion dask_cuda/cli.py
@@ -13,7 +13,7 @@
 from distributed.utils import import_term
 
 from .cuda_worker import CUDAWorker
-from .utils import print_cluster_config
+from .utils import CommaSeparatedChoice, print_cluster_config
 
 logger = logging.getLogger(__name__)

@@ -167,6 +167,7 @@ def cuda():
 @click.option(
     "--set-rmm-allocator-for-libs",
     "rmm_allocator_external_lib_list",
+    type=CommaSeparatedChoice(["cupy", "torch"]),
     default=None,
     show_default=True,
     help="""
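For context, a minimal standalone sketch of how the new option type behaves (not part of the commit; the `demo` command is hypothetical, while `CommaSeparatedChoice` is the class added to `dask_cuda/utils.py` below):

    import click

    from dask_cuda.utils import CommaSeparatedChoice


    @click.command()
    @click.option(
        "--set-rmm-allocator-for-libs",
        "libs",
        type=CommaSeparatedChoice(["cupy", "torch"]),
        default=None,
    )
    def demo(libs):
        # click splits "cupy,torch" into ["cupy", "torch"] and validates each
        # entry against the allowed choices before the function body runs.
        click.echo(libs)


    if __name__ == "__main__":
        demo()

Invoked as `python demo.py --set-rmm-allocator-for-libs cupy,torch`, this prints `['cupy', 'torch']`; an unknown name fails with the "invalid choice(s)" message defined in `CommaSeparatedChoice`.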
3 changes: 0 additions & 3 deletions dask_cuda/cuda_worker.py
@@ -203,9 +203,6 @@ def del_pid_file():
"processes set `CUDF_SPILL=on` as well. To disable this warning "
"set `DASK_CUDF_SPILL_WARNING=False`."
)

if rmm_allocator_external_lib_list is not None:
rmm_allocator_external_lib_list = [s.strip() for s in rmm_allocator_external_lib_list.split(',')]

self.nannies = [
Nanny(
14 changes: 4 additions & 10 deletions dask_cuda/local_cuda_cluster.py
@@ -143,10 +143,10 @@ class LocalCUDACluster(LocalCluster):
         The asynchronous allocator requires CUDA Toolkit 11.2 or newer. It is also
         incompatible with RMM pools and managed memory. Trying to enable both will
         result in an exception.
-    rmm_allocator_external_lib_list: str or list or None, default None
-        Set RMM as the allocator for external libraries. Can be a comma-separated
-        string (like "torch,cupy").
+    rmm_allocator_external_lib_list: list or None, default None
+        List of external libraries for which to set RMM as the allocator.
+        Supported options are: ``["torch", "cupy"]``. If None, no external
+        libraries will use RMM as their allocator.
     rmm_release_threshold: int, str or None, default None
         When ``rmm.async is True`` and the pool size grows beyond this value, unused
         memory held by the pool will be released at the next synchronization point.
@@ -289,12 +289,6 @@ def __init__(
         self.rmm_managed_memory = rmm_managed_memory
         self.rmm_async = rmm_async
         self.rmm_release_threshold = rmm_release_threshold
-        if rmm_allocator_external_lib_list is not None and isinstance(
-            rmm_allocator_external_lib_list, str
-        ):
-            rmm_allocator_external_lib_list = [
-                s.strip() for s in rmm_allocator_external_lib_list.split(",")
-            ]
         self.rmm_allocator_external_lib_list = rmm_allocator_external_lib_list
 
         if rmm_pool_size is not None or rmm_managed_memory or rmm_async:
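A hedged usage sketch of the revised parameter (assumes a CUDA-capable machine with RMM, CuPy, and PyTorch installed; after this change the argument is a list, and the CLI layer, not this class, handles comma-separated strings):

    from distributed import Client

    from dask_cuda import LocalCUDACluster

    # Each listed library has its allocator pointed at RMM on every worker.
    cluster = LocalCUDACluster(
        rmm_allocator_external_lib_list=["torch", "cupy"],
    )
    client = Client(cluster)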
20 changes: 17 additions & 3 deletions dask_cuda/utils.py
@@ -7,8 +7,9 @@
 from contextlib import suppress
 from functools import singledispatch
 from multiprocessing import cpu_count
-from typing import Optional, Callable, Dict
+from typing import Callable, Dict, Optional
 
+import click
 import numpy as np
 import pynvml
 import toolz
@@ -771,7 +772,7 @@ def enable_rmm_memory_for_library(lib_name: str) -> None:
     Enable RMM memory pool support for a specified third-party library.
 
     This function allows the given library to utilize RMM's memory pool if it supports
-    integration with RMM. The library name is passed as a string argument, and if the
+    integration with RMM. The library name is passed as a string argument, and if the
     library is compatible, its memory allocator will be configured to use RMM.
 
     Parameters
@@ -794,7 +795,7 @@ def enable_rmm_memory_for_library(lib_name: str) -> None:
     }
 
     if lib_name not in setup_functions:
-        supported_libs = ', '.join(setup_functions.keys())
+        supported_libs = ", ".join(setup_functions.keys())
         raise ValueError(
             f"The library '{lib_name}' is not supported for RMM integration. "
             f"Supported libraries are: {supported_libs}."
@@ -803,6 +804,7 @@ def enable_rmm_memory_for_library(lib_name: str) -> None:
     # Call the setup function for the specified library
     setup_functions[lib_name]()
 
+
 def _setup_rmm_for_torch() -> None:
     try:
         import torch
@@ -813,11 +815,23 @@ def _setup_rmm_for_torch() -> None:

     torch.cuda.memory.change_current_allocator(rmm_torch_allocator)
 
+
 def _setup_rmm_for_cupy() -> None:
     try:
         import cupy
     except ImportError as e:
         raise ImportError("CuPy is not installed.") from e
 
     from rmm.allocators.cupy import rmm_cupy_allocator
 
     cupy.cuda.set_allocator(rmm_cupy_allocator)
+
+
+class CommaSeparatedChoice(click.Choice):
+    def convert(self, value, param, ctx):
+        values = [v.strip() for v in value.split(",")]
+        for v in values:
+            if v not in self.choices:
+                choices_str = ", ".join(f"'{c}'" for c in self.choices)
+                self.fail(f"invalid choice(s): {v}. (choices are: {choices_str})")
+        return values
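Taken together, a minimal sketch of the two utilities (assumes dask_cuda at this commit with RMM and CuPy installed; the setup call configures GPU allocators, so it needs CUDA hardware):

    from dask_cuda.utils import CommaSeparatedChoice, enable_rmm_memory_for_library

    # Parse and validate a comma-separated value the way the CLI option would.
    choice = CommaSeparatedChoice(["cupy", "torch"])
    libs = choice.convert("cupy, torch", None, None)  # -> ["cupy", "torch"]

    for lib in libs:
        # Routes the library's allocations through RMM; unsupported names
        # raise ValueError listing the supported libraries.
        enable_rmm_memory_for_library(lib)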
