diff --git a/dask_cuda/tests/test_spill.py b/dask_cuda/tests/test_spill.py index 6a542cfb9..6172b0bc6 100644 --- a/dask_cuda/tests/test_spill.py +++ b/dask_cuda/tests/test_spill.py @@ -1,3 +1,4 @@ +import gc import os from time import sleep @@ -58,7 +59,10 @@ def assert_device_host_file_size( def worker_assert( - dask_worker, total_size, device_chunk_overhead, serialized_chunk_overhead + total_size, + device_chunk_overhead, + serialized_chunk_overhead, + dask_worker=None, ): assert_device_host_file_size( dask_worker.data, total_size, device_chunk_overhead, serialized_chunk_overhead @@ -66,7 +70,10 @@ def worker_assert( def delayed_worker_assert( - dask_worker, total_size, device_chunk_overhead, serialized_chunk_overhead + total_size, + device_chunk_overhead, + serialized_chunk_overhead, + dask_worker=None, ): start = time() while not device_host_file_size_matches( @@ -82,6 +89,18 @@ def delayed_worker_assert( ) +def assert_host_chunks(spills_to_disk, dask_worker=None): + if spills_to_disk is False: + assert len(dask_worker.data.host) + + +def assert_disk_chunks(spills_to_disk, dask_worker=None): + if spills_to_disk is True: + assert len(dask_worker.data.disk or list()) > 0 + else: + assert len(dask_worker.data.disk or list()) == 0 + + @pytest.mark.parametrize( "params", [ @@ -122,7 +141,7 @@ def delayed_worker_assert( }, ], ) -@gen_test(timeout=120) +@gen_test(timeout=30) async def test_cupy_cluster_device_spill(params): cupy = pytest.importorskip("cupy") with dask.config.set( @@ -144,6 +163,8 @@ async def test_cupy_cluster_device_spill(params): ) as cluster: async with Client(cluster, asynchronous=True) as client: + await client.wait_for_workers(1) + rs = da.random.RandomState(RandomState=cupy.random.RandomState) x = rs.random(int(50e6), chunks=2e6) await wait(x) @@ -153,7 +174,10 @@ async def test_cupy_cluster_device_spill(params): # Allow up to 1024 bytes overhead per chunk serialized await client.run( - lambda dask_worker: worker_assert(dask_worker, x.nbytes, 1024, 1024) + worker_assert, + x.nbytes, + 1024, + 1024, ) y = client.compute(x.sum()) @@ -162,20 +186,19 @@ async def test_cupy_cluster_device_spill(params): assert (abs(res / x.size) - 0.5) < 1e-3 await client.run( - lambda dask_worker: worker_assert(dask_worker, x.nbytes, 1024, 1024) + worker_assert, + x.nbytes, + 1024, + 1024, ) - host_chunks = await client.run( - lambda dask_worker: len(dask_worker.data.host) + await client.run( + assert_host_chunks, + params["spills_to_disk"], ) - disk_chunks = await client.run( - lambda dask_worker: len(dask_worker.data.disk or list()) + await client.run( + assert_disk_chunks, + params["spills_to_disk"], ) - for hc, dc in zip(host_chunks.values(), disk_chunks.values()): - if params["spills_to_disk"]: - assert dc > 0 - else: - assert hc > 0 - assert dc == 0 @pytest.mark.parametrize( @@ -218,7 +241,7 @@ async def test_cupy_cluster_device_spill(params): }, ], ) -@gen_test(timeout=120) +@gen_test(timeout=30) async def test_cudf_cluster_device_spill(params): cudf = pytest.importorskip("cudf") @@ -243,6 +266,8 @@ async def test_cudf_cluster_device_spill(params): ) as cluster: async with Client(cluster, asynchronous=True) as client: + await client.wait_for_workers(1) + # There's a known issue with datetime64: # https://github.com/numpy/numpy/issues/4983#issuecomment-441332940 # The same error above happens when spilling datetime64 to disk @@ -264,26 +289,35 @@ async def test_cudf_cluster_device_spill(params): await wait(cdf2) del cdf + gc.collect() - host_chunks = await client.run( - lambda dask_worker: len(dask_worker.data.host) + await client.run( + assert_host_chunks, + params["spills_to_disk"], ) - disk_chunks = await client.run( - lambda dask_worker: len(dask_worker.data.disk or list()) + await client.run( + assert_disk_chunks, + params["spills_to_disk"], ) - for hc, dc in zip(host_chunks.values(), disk_chunks.values()): - if params["spills_to_disk"]: - assert dc > 0 - else: - assert hc > 0 - assert dc == 0 await client.run( - lambda dask_worker: worker_assert(dask_worker, nbytes, 32, 2048) + worker_assert, + nbytes, + 32, + 2048, ) del cdf2 - await client.run( - lambda dask_worker: delayed_worker_assert(dask_worker, 0, 0, 0) - ) + while True: + try: + await client.run( + delayed_worker_assert, + 0, + 0, + 0, + ) + except AssertionError: + gc.collect() + else: + break