Commit: test _partitions
madsbk committed Sep 26, 2023
1 parent 4aa35aa commit c6950a9
Showing 1 changed file with 29 additions and 13 deletions: dask_cuda/tests/test_explicit_comms.py
@@ -93,7 +93,7 @@ def check_partitions(df, npartitions):
     return True


-def _test_dataframe_shuffle(backend, protocol, n_workers):
+def _test_dataframe_shuffle(backend, protocol, n_workers, _partitions):
     if backend == "cudf":
         cudf = pytest.importorskip("cudf")
@@ -112,6 +112,9 @@ def _test_dataframe_shuffle(backend, protocol, n_workers):
             if backend == "cudf":
                 df = cudf.DataFrame.from_pandas(df)

+            if _partitions:
+                df["_partitions"] = 0
+
             for input_nparts in range(1, 5):
                 for output_nparts in range(1, 5):
                     ddf = dd.from_pandas(df.copy(), npartitions=input_nparts).persist(
@@ -123,33 +126,46 @@
                         with dask.config.set(explicit_comms_batchsize=batchsize):
                             ddf = explicit_comms_shuffle(
                                 ddf,
-                                ["key"],
+                                ["_partitions"] if _partitions else ["key"],
                                 npartitions=output_nparts,
                                 batchsize=batchsize,
                             ).persist()

                             assert ddf.npartitions == output_nparts

-                            # Check that each partition hashes to the same value
-                            result = ddf.map_partitions(
-                                check_partitions, output_nparts
-                            ).compute()
-                            assert all(result.to_list())
+                            if _partitions:
+                                # If "_partitions" is the hash key, we expect all but
+                                # the first partition to be empty
+                                assert_eq(ddf.partitions[0].compute(), df)
+                                assert all(
+                                    len(ddf.partitions[i].compute()) == 0
+                                    for i in range(1, ddf.npartitions)
+                                )
+                            else:
+
+                                # Check that each partition hashes to the same value
+                                result = ddf.map_partitions(
+                                    check_partitions, output_nparts
+                                ).compute()
+                                assert all(result.to_list())

-                            # Check the values (ignoring the row order)
-                            expected = df.sort_values("key")
-                            got = ddf.compute().sort_values("key")
-                            assert_eq(got, expected)
+                                # Check the values (ignoring the row order)
+                                expected = df.sort_values("key")
+                                got = ddf.compute().sort_values("key")
+                                assert_eq(got, expected)


 @pytest.mark.parametrize("nworkers", [1, 2, 3])
 @pytest.mark.parametrize("backend", ["pandas", "cudf"])
 @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
-def test_dataframe_shuffle(backend, protocol, nworkers):
+@pytest.mark.parametrize("_partitions", [True, False])
+def test_dataframe_shuffle(backend, protocol, nworkers, _partitions):
     if backend == "cudf":
         pytest.importorskip("cudf")

-    p = mp.Process(target=_test_dataframe_shuffle, args=(backend, protocol, nworkers))
+    p = mp.Process(
+        target=_test_dataframe_shuffle, args=(backend, protocol, nworkers, _partitions)
+    )
     p.start()
     p.join()
     assert not p.exitcode

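The new `_partitions=True` case asserts a specific behaviour: when the shuffle key is the special "_partitions" column and every row carries the value 0, all rows end up in output partition 0 and every other partition is empty, mirroring dask's own convention of using "_partitions" values directly as output partition indices rather than hashing them. The sketch below reproduces that check outside pytest; it is only a sketch, assuming the import path and `shuffle(df, column_names, npartitions=...)` signature that this test file uses, and like the test it needs a running `LocalCluster`/`Client`.

```python
import pandas as pd
import dask.dataframe as dd
from dask.distributed import Client, LocalCluster

# Assumed import path/alias, matching how the test imports the shuffle function.
from dask_cuda.explicit_comms.dataframe.shuffle import shuffle as explicit_comms_shuffle

if __name__ == "__main__":
    with LocalCluster(n_workers=2, threads_per_worker=1, processes=True) as cluster:
        with Client(cluster):
            # Every row targets output partition 0 via the "_partitions" column.
            df = pd.DataFrame({"key": range(100)})
            df["_partitions"] = 0

            ddf = dd.from_pandas(df, npartitions=4)
            shuffled = explicit_comms_shuffle(
                ddf, ["_partitions"], npartitions=3
            ).persist()

            # Expect all rows in partition 0 and the remaining partitions empty.
            assert len(shuffled.partitions[0].compute()) == len(df)
            assert all(
                len(shuffled.partitions[i].compute()) == 0
                for i in range(1, shuffled.npartitions)
            )
```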