From bf60373519229de3b8494dbad9ff0bbb6813ed24 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 5 Jun 2023 22:00:06 +0200 Subject: [PATCH] Specify disk spill compression based on Dask config (#1190) Spill to disk compression was introduced in https://github.com/dask/distributed/pull/7768 and Dask-CUDA should also allow modifying the default compression via Dask config. This change is required to support `distributed>=2023.5.0`. Authors: - Peter Andreas Entschev (https://github.com/pentschev) Approvers: - GALI PREM SAGAR (https://github.com/galipremsagar) URL: https://github.com/rapidsai/dask-cuda/pull/1190 --- dask_cuda/device_host_file.py | 17 ++++++++++++++--- dask_cuda/tests/test_spill.py | 1 + 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/dask_cuda/device_host_file.py b/dask_cuda/device_host_file.py index a0fe92e8a..e8d8bc08b 100644 --- a/dask_cuda/device_host_file.py +++ b/dask_cuda/device_host_file.py @@ -1,4 +1,3 @@ -import functools import itertools import logging import os @@ -8,6 +7,7 @@ from zict import Buffer, File, Func from zict.common import ZictBase +import dask from distributed.protocol import ( dask_deserialize, dask_serialize, @@ -17,13 +17,24 @@ serialize_bytelist, ) from distributed.sizeof import safe_sizeof -from distributed.utils import nbytes +from distributed.utils import has_arg, nbytes from .is_device_object import is_device_object from .is_spillable_object import is_spillable_object from .utils import nvtx_annotate +def _serialize_bytelist(x, **kwargs): + kwargs["on_error"] = "raise" + + if has_arg(serialize_bytelist, "compression"): + compression = dask.config.get("distributed.worker.memory.spill-compression") + return serialize_bytelist(x, compression=compression, **kwargs) + else: + # For Distributed < 2023.5.0 compatibility + return serialize_bytelist(x, **kwargs) + + class LoggedBuffer(Buffer): """Extends zict.Buffer with logging capabilities @@ -192,7 +203,7 @@ def __init__( self.host_func = dict() self.disk_func = Func( - functools.partial(serialize_bytelist, on_error="raise"), + _serialize_bytelist, deserialize_bytes, File(self.disk_func_path), ) diff --git a/dask_cuda/tests/test_spill.py b/dask_cuda/tests/test_spill.py index d795f8f8d..cd36cb781 100644 --- a/dask_cuda/tests/test_spill.py +++ b/dask_cuda/tests/test_spill.py @@ -220,6 +220,7 @@ async def test_cudf_cluster_device_spill(params): { "distributed.comm.compression": False, "distributed.worker.memory.terminate": False, + "distributed.worker.memory.spill-compression": False, } ): async with LocalCUDACluster(