diff --git a/dask_cuda/tests/test_explicit_comms.py b/dask_cuda/tests/test_explicit_comms.py index d1024ff69..1a15370b5 100644 --- a/dask_cuda/tests/test_explicit_comms.py +++ b/dask_cuda/tests/test_explicit_comms.py @@ -17,8 +17,6 @@ import dask_cuda from dask_cuda.explicit_comms import comms from dask_cuda.explicit_comms.dataframe.shuffle import shuffle as explicit_comms_shuffle -from dask_cuda.initialize import initialize -from dask_cuda.utils import get_ucx_config mp = mp.get_context("spawn") # type: ignore ucp = pytest.importorskip("ucp") @@ -32,14 +30,6 @@ async def my_rank(state, arg): def _test_local_cluster(protocol): - dask.config.update( - dask.config.global_config, - { - "distributed.comm.ucx": get_ucx_config(enable_tcp_over_ucx=True), - }, - priority="new", - ) - with LocalCluster( protocol=protocol, dashboard_address=None, @@ -106,15 +96,6 @@ def check_partitions(df, npartitions): def _test_dataframe_shuffle(backend, protocol, n_workers): if backend == "cudf": cudf = pytest.importorskip("cudf") - initialize(enable_tcp_over_ucx=True) - else: - dask.config.update( - dask.config.global_config, - { - "distributed.comm.ucx": get_ucx_config(enable_tcp_over_ucx=True), - }, - priority="new", - ) with LocalCluster( protocol=protocol, @@ -220,17 +201,6 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers): if backend == "cudf": cudf = pytest.importorskip("cudf") - initialize(enable_tcp_over_ucx=True) - else: - - dask.config.update( - dask.config.global_config, - { - "distributed.comm.ucx": get_ucx_config(enable_tcp_over_ucx=True), - }, - priority="new", - ) - with LocalCluster( protocol=protocol, dashboard_address=None, @@ -287,7 +257,6 @@ def _test_jit_unspill(protocol): threads_per_worker=1, jit_unspill=True, device_memory_limit="1B", - enable_tcp_over_ucx=True if protocol == "ucx" else False, ) as cluster: with Client(cluster): np.random.seed(42) diff --git a/dask_cuda/tests/test_local_cuda_cluster.py b/dask_cuda/tests/test_local_cuda_cluster.py index f2e48783c..e087fb70b 100644 --- a/dask_cuda/tests/test_local_cuda_cluster.py +++ b/dask_cuda/tests/test_local_cuda_cluster.py @@ -87,14 +87,25 @@ def get_visible_devices(): } -@pytest.mark.parametrize("protocol", ["ucx", None]) @gen_test(timeout=20) -async def test_ucx_protocol(protocol): +async def test_ucx_protocol(): + pytest.importorskip("ucp") + + async with LocalCUDACluster( + protocol="ucx", asynchronous=True, data=dict + ) as cluster: + assert all( + ws.address.startswith("ucx://") for ws in cluster.scheduler.workers.values() + ) + + +@gen_test(timeout=20) +async def test_explicit_ucx_with_protocol_none(): pytest.importorskip("ucp") initialize(enable_tcp_over_ucx=True) async with LocalCUDACluster( - protocol=protocol, enable_tcp_over_ucx=True, asynchronous=True, data=dict + protocol=None, enable_tcp_over_ucx=True, asynchronous=True, data=dict ) as cluster: assert all( ws.address.startswith("ucx://") for ws in cluster.scheduler.workers.values() diff --git a/dask_cuda/tests/test_proxy.py b/dask_cuda/tests/test_proxy.py index 1a4abafe9..cfdbf636b 100644 --- a/dask_cuda/tests/test_proxy.py +++ b/dask_cuda/tests/test_proxy.py @@ -422,7 +422,6 @@ def task(x): async with dask_cuda.LocalCUDACluster( n_workers=1, protocol=protocol, - enable_tcp_over_ucx=protocol == "ucx", asynchronous=True, ) as cluster: async with Client(cluster, asynchronous=True) as client: @@ -462,7 +461,6 @@ def task(x): async with dask_cuda.LocalCUDACluster( n_workers=1, protocol=protocol, - enable_tcp_over_ucx=protocol == "ucx", asynchronous=True, ) as cluster: async with Client(cluster, asynchronous=True) as client: