From fa9ad45929127893be6326826fa932e9fbc3a2b9 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Wed, 27 Sep 2023 13:11:00 +0200 Subject: [PATCH 1/3] decompress pickled messages --- distributed/protocol/core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/distributed/protocol/core.py b/distributed/protocol/core.py index d58ee011297..58740af62ce 100644 --- a/distributed/protocol/core.py +++ b/distributed/protocol/core.py @@ -146,6 +146,8 @@ def _decode_default(obj): sub_header = msgpack.loads(frames[offset]) offset += 1 sub_frames = frames[offset : offset + sub_header["num-sub-frames"]] + if "compression" in sub_header: + sub_frames = decompress(sub_header, sub_frames) if allow_pickle: return pickle.loads(sub_header["pickled-obj"], buffers=sub_frames) else: From b6c1aa862326b50a39de23a1afc3435cb435fae3 Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Wed, 27 Sep 2023 14:22:27 +0200 Subject: [PATCH 2/3] added test --- distributed/comm/tests/test_ucx.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index 5ecefcd6093..3863e16c82a 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -413,3 +413,22 @@ async def test_comm_closed_on_read_error(): await wait_for(reader.read(), 0.01) assert reader.closed() + + +@gen_test() +async def test_embedded_cupy_array( + ucx_loop, +): + cupy = pytest.importorskip("cupy") + da = pytest.importorskip("dask.array") + np = pytest.importorskip("numpy") + + async with LocalCluster( + protocol="ucx", n_workers=1, threads_per_worker=1, asynchronous=True + ) as cluster: + async with Client(cluster, asynchronous=True): + assert cluster.scheduler_address.startswith("ucx://") + a = cupy.arange(10000) + x = da.from_array(a, chunks=(10000,)) + await x + np.testing.assert_array_equal(a, x.compute()) From a051bdd87a770cd2004eae1df11ca042cf7f5f5e Mon Sep 17 00:00:00 2001 From: "Mads R. B. Kristensen" Date: Wed, 27 Sep 2023 15:17:56 +0200 Subject: [PATCH 3/3] clean up --- distributed/comm/tests/test_ucx.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index 3863e16c82a..aa5a2824eeb 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -426,9 +426,9 @@ async def test_embedded_cupy_array( async with LocalCluster( protocol="ucx", n_workers=1, threads_per_worker=1, asynchronous=True ) as cluster: - async with Client(cluster, asynchronous=True): + async with Client(cluster, asynchronous=True) as client: assert cluster.scheduler_address.startswith("ucx://") a = cupy.arange(10000) x = da.from_array(a, chunks=(10000,)) - await x - np.testing.assert_array_equal(a, x.compute()) + b = await client.compute(x) + cupy.testing.assert_array_equal(a, b)