From 1d6d5e82899b53f0d5c996501b92a07a2ee1a1df Mon Sep 17 00:00:00 2001
From: Peter Andreas Entschev
Date: Wed, 23 Oct 2024 13:49:31 -0700
Subject: [PATCH] Add tests for GPU resource definition

---
 dask_cuda/tests/test_dask_cuda_worker.py   | 42 ++++++++++++++++++++++
 dask_cuda/tests/test_local_cuda_cluster.py | 20 +++++++++++
 2 files changed, 62 insertions(+)

diff --git a/dask_cuda/tests/test_dask_cuda_worker.py b/dask_cuda/tests/test_dask_cuda_worker.py
index 049fe85f..c32fa93d 100644
--- a/dask_cuda/tests/test_dask_cuda_worker.py
+++ b/dask_cuda/tests/test_dask_cuda_worker.py
@@ -594,3 +594,45 @@ def test_worker_cudf_spill_warning(enable_cudf_spill_warning):  # noqa: F811
         assert b"UserWarning: cuDF spilling is enabled" in ret.stderr
     else:
         assert b"UserWarning: cuDF spilling is enabled" not in ret.stderr
+
+
+def test_worker_gpu_resource(loop):  # noqa: F811
+    with popen(["dask", "scheduler", "--port", "9369", "--no-dashboard"]):
+        with popen(
+            [
+                "dask",
+                "cuda",
+                "worker",
+                "127.0.0.1:9369",
+                "--no-dashboard",
+            ]
+        ):
+            with Client("127.0.0.1:9369", loop=loop) as client:
+                assert wait_workers(client, n_gpus=get_n_gpus())
+
+                workers = client.scheduler_info()["workers"]
+                for v in workers.values():
+                    assert "GPU" in v["resources"]
+                    assert v["resources"]["GPU"] == 1
+
+
+def test_worker_gpu_resource_user_defined(loop):  # noqa: F811
+    with popen(["dask", "scheduler", "--port", "9369", "--no-dashboard"]):
+        with popen(
+            [
+                "dask",
+                "cuda",
+                "worker",
+                "127.0.0.1:9369",
+                "--resources",
+                "GPU=55",
+                "--no-dashboard",
+            ]
+        ):
+            with Client("127.0.0.1:9369", loop=loop) as client:
+                assert wait_workers(client, n_gpus=get_n_gpus())
+
+                workers = client.scheduler_info()["workers"]
+                for v in workers.values():
+                    assert "GPU" in v["resources"]
+                    assert v["resources"]["GPU"] == 55
diff --git a/dask_cuda/tests/test_local_cuda_cluster.py b/dask_cuda/tests/test_local_cuda_cluster.py
index b144d111..2b15973f 100644
--- a/dask_cuda/tests/test_local_cuda_cluster.py
+++ b/dask_cuda/tests/test_local_cuda_cluster.py
@@ -217,6 +217,26 @@ async def test_all_to_all():
     assert all(all_data.count(i) == n_workers for i in all_data)
 
 
+@gen_test(timeout=20)
+async def test_worker_gpu_resource():
+    async with LocalCUDACluster(asynchronous=True) as cluster:
+        async with Client(cluster, asynchronous=True) as client:
+            workers = client.scheduler_info()["workers"]
+            for v in workers.values():
+                assert "GPU" in v["resources"]
+                assert v["resources"]["GPU"] == 1
+
+
+@gen_test(timeout=20)
+async def test_worker_gpu_resource_user_defined():
+    async with LocalCUDACluster(asynchronous=True, resources={"GPU": 55}) as cluster:
+        async with Client(cluster, asynchronous=True) as client:
+            workers = client.scheduler_info()["workers"]
+            for v in workers.values():
+                assert "GPU" in v["resources"]
+                assert v["resources"]["GPU"] == 55
+
+
 @gen_test(timeout=20)
 async def test_rmm_pool():
     rmm = pytest.importorskip("rmm")