diff --git a/dask_cuda/benchmarks/common.py b/dask_cuda/benchmarks/common.py index e734f882c..0b417e7b3 100644 --- a/dask_cuda/benchmarks/common.py +++ b/dask_cuda/benchmarks/common.py @@ -121,6 +121,8 @@ def run(client: Client, args: Namespace, config: Config): args.type == "gpu", args.rmm_pool_size, args.disable_rmm_pool, + args.enable_rmm_async, + args.enable_rmm_managed, args.rmm_log_directory, args.enable_rmm_statistics, ) diff --git a/dask_cuda/benchmarks/utils.py b/dask_cuda/benchmarks/utils.py index 1d07df30c..a3d51066a 100644 --- a/dask_cuda/benchmarks/utils.py +++ b/dask_cuda/benchmarks/utils.py @@ -98,6 +98,16 @@ def parse_benchmark_args(description="Generic dask-cuda Benchmark", args_list=[] cluster_args.add_argument( "--disable-rmm-pool", action="store_true", help="Disable the RMM memory pool" ) + cluster_args.add_argument( + "--enable-rmm-managed", + action="store_true", + help="Enable RMM managed memory allocator", + ) + cluster_args.add_argument( + "--enable-rmm-async", + action="store_true", + help="Enable RMM async memory allocator (implies --disable-rmm-pool)", + ) cluster_args.add_argument( "--rmm-log-directory", default=None, @@ -346,6 +356,8 @@ def setup_memory_pool( dask_worker=None, pool_size=None, disable_pool=False, + rmm_async=False, + rmm_managed=False, log_directory=None, statistics=False, ): @@ -357,10 +369,13 @@ def setup_memory_pool( logging = log_directory is not None - if not disable_pool: + if rmm_async: + rmm.mr.set_current_device_resource(rmm.mr.CudaAsyncMemoryResource()) + cupy.cuda.set_allocator(rmm.rmm_cupy_allocator) + else: rmm.reinitialize( - pool_allocator=True, - devices=0, + pool_allocator=not disable_pool, + managed_memory=rmm_managed, initial_pool_size=pool_size, logging=logging, log_file_name=get_rmm_log_file_name(dask_worker, logging, log_directory), @@ -373,7 +388,14 @@ def setup_memory_pool( def setup_memory_pools( - client, is_gpu, pool_size, disable_pool, log_directory, statistics + client, + is_gpu, + pool_size, + disable_pool, + rmm_async, + rmm_managed, + log_directory, + statistics, ): if not is_gpu: return @@ -381,6 +403,8 @@ def setup_memory_pools( setup_memory_pool, pool_size=pool_size, disable_pool=disable_pool, + rmm_async=rmm_async, + rmm_managed=rmm_managed, log_directory=log_directory, statistics=statistics, ) @@ -390,6 +414,8 @@ def setup_memory_pools( setup_memory_pool, pool_size=1e9, disable_pool=disable_pool, + rmm_async=rmm_async, + rmm_managed=rmm_managed, log_directory=log_directory, statistics=statistics, )