diff --git a/dask_cuda/tests/test_dgx.py b/dask_cuda/tests/test_dgx.py index d57cf1a3..f8a2fce8 100644 --- a/dask_cuda/tests/test_dgx.py +++ b/dask_cuda/tests/test_dgx.py @@ -15,6 +15,10 @@ psutil = pytest.importorskip("psutil") +def _is_ucx_116(ucp): + return ucp.get_ucx_version()[:2] == (1, 16) + + class DGXVersion(Enum): DGX_1 = auto() DGX_2 = auto() @@ -102,9 +106,13 @@ def check_ucx_options(): ) def test_tcp_over_ucx(protocol): if protocol == "ucx": - pytest.importorskip("ucp") + ucp = pytest.importorskip("ucp") elif protocol == "ucxx": - pytest.importorskip("ucxx") + ucp = pytest.importorskip("ucxx") + if _is_ucx_116(ucp): + pytest.skip( + "Wireup may fail in UCX 1.16 in nodes with multiple NICs if TCP is used" + ) p = mp.Process(target=_test_tcp_over_ucx, args=(protocol,)) p.start() @@ -217,9 +225,13 @@ def check_ucx_options(): ) def test_ucx_infiniband_nvlink(protocol, params): if protocol == "ucx": - pytest.importorskip("ucp") + ucp = pytest.importorskip("ucp") elif protocol == "ucxx": - pytest.importorskip("ucxx") + ucp = pytest.importorskip("ucxx") + if _is_ucx_116(ucp) and params["enable_infiniband"] is False: + pytest.skip( + "Wireup may fail in UCX 1.16 in nodes with multiple NICs if TCP is used" + ) skip_queue = mp.Queue()