diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e5f44cd88..27528dfce 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
repos:
- - repo: https://github.com/timothycrosley/isort
- rev: 5.0.7
+ - repo: https://github.com/pycqa/isort
+ rev: 5.6.4
hooks:
- id: isort
- repo: https://github.com/ambv/black
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2898e7136..4a63da005 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,7 @@
+# dask-cuda 21.10.00 (Date TBD)
+
+Please see https://github.com/rapidsai/dask-cuda/releases/tag/v21.10.00a for the latest changes to this development branch.
+
# dask-cuda 21.08.00 (4 Aug 2021)
## 🐛 Bug Fixes
diff --git a/ci/gpu/build.sh b/ci/gpu/build.sh
index e935da7fe..6626629d6 100755
--- a/ci/gpu/build.sh
+++ b/ci/gpu/build.sh
@@ -54,18 +54,18 @@ conda list --show-channel-urls
# Fixing Numpy version to avoid RuntimeWarning: numpy.ufunc size changed, may
# indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject
-gpuci_conda_retry install "cudatoolkit=$CUDA_REL" \
+gpuci_mamba_retry install "cudatoolkit=$CUDA_REL" \
"cudf=${MINOR_VERSION}" "dask-cudf=${MINOR_VERSION}" \
- "ucx-py=0.21.*" "ucx-proc=*=gpu" \
+ "ucx-py=0.22.*" "ucx-proc=*=gpu" \
"rapids-build-env=$MINOR_VERSION.*"
# Pin pytest-asyncio because latest versions modify the default asyncio
# `event_loop_policy`. See https://github.com/dask/distributed/pull/4212 .
-gpuci_conda_retry install "pytest-asyncio=<0.14.0"
+gpuci_mamba_retry install "pytest-asyncio=<0.14.0"
# https://docs.rapids.ai/maintainers/depmgmt/
-# gpuci_conda_retry remove -f rapids-build-env
-# gpuci_conda_retry install "your-pkg=1.0.0"
+# gpuci_mamba_retry remove -f rapids-build-env
+# gpuci_mamba_retry install "your-pkg=1.0.0"
conda info
@@ -106,24 +106,7 @@ else
gpuci_logger "Python pytest for dask-cuda"
cd "$WORKSPACE"
ls dask_cuda/tests/
- UCXPY_IFNAME=eth0 UCX_WARN_UNUSED_ENV_VARS=n UCX_MEMTYPE_CACHE=n pytest -vs -Werror::DeprecationWarning -Werror::FutureWarning --cache-clear --basetemp="$WORKSPACE/dask-cuda-tmp" --junitxml="$WORKSPACE/junit-dask-cuda.xml" --cov-config=.coveragerc --cov=dask_cuda --cov-report=xml:"$WORKSPACE/dask-cuda-coverage.xml" --cov-report term dask_cuda/tests/
-
- gpuci_logger "Running dask.distributed GPU tests"
- # Test downstream packages, which requires Python v3.7
- if [ $(python -c "import sys; print(sys.version_info[1])") -ge "7" ]; then
- # Clone Distributed to avoid pytest cleanup fixture errors
- # See https://github.com/dask/distributed/issues/4902
- gpuci_logger "Clone Distributed"
- git clone https://github.com/dask/distributed
-
- gpuci_logger "Run Distributed Tests"
- pytest --cache-clear -vs -Werror::DeprecationWarning -Werror::FutureWarning distributed/distributed/protocol/tests/test_cupy.py
- pytest --cache-clear -vs -Werror::DeprecationWarning -Werror::FutureWarning distributed/distributed/protocol/tests/test_numba.py
- pytest --cache-clear -vs -Werror::DeprecationWarning -Werror::FutureWarning distributed/distributed/protocol/tests/test_rmm.py
- pytest --cache-clear -vs -Werror::DeprecationWarning -Werror::FutureWarning distributed/distributed/protocol/tests/test_collection_cuda.py
- pytest --cache-clear -vs -Werror::DeprecationWarning -Werror::FutureWarning distributed/distributed/tests/test_nanny.py
- pytest --cache-clear -vs -Werror::DeprecationWarning -Werror::FutureWarning distributed/distributed/diagnostics/tests/test_nvml.py
- fi
+ DASK_CUDA_TEST_SINGLE_GPU=1 UCXPY_IFNAME=eth0 UCX_WARN_UNUSED_ENV_VARS=n UCX_MEMTYPE_CACHE=n pytest -vs -Werror::DeprecationWarning -Werror::FutureWarning --cache-clear --basetemp="$WORKSPACE/dask-cuda-tmp" --junitxml="$WORKSPACE/junit-dask-cuda.xml" --cov-config=.coveragerc --cov=dask_cuda --cov-report=xml:"$WORKSPACE/dask-cuda-coverage.xml" --cov-report term dask_cuda/tests/
logger "Run local benchmark..."
python dask_cuda/benchmarks/local_cudf_shuffle.py --partition-size="1 KiB" -d 0 --runs 1 --backend dask
diff --git a/conda/recipes/dask-cuda/meta.yaml b/conda/recipes/dask-cuda/meta.yaml
index e5b7f0609..fbd79bf09 100644
--- a/conda/recipes/dask-cuda/meta.yaml
+++ b/conda/recipes/dask-cuda/meta.yaml
@@ -27,8 +27,8 @@ requirements:
- setuptools
run:
- python
- - dask >=2.22.0,<=2021.07.1
- - distributed >=2.22.0,<=2021.07.1
+ - dask=2021.09.1
+ - distributed=2021.09.1
- pynvml >=8.0.3
- numpy >=1.16.0
- numba >=0.53.1
diff --git a/dask_cuda/benchmarks/local_cudf_merge.py b/dask_cuda/benchmarks/local_cudf_merge.py
index e6e301905..f36be7478 100644
--- a/dask_cuda/benchmarks/local_cudf_merge.py
+++ b/dask_cuda/benchmarks/local_cudf_merge.py
@@ -1,6 +1,7 @@
import contextlib
import math
from collections import defaultdict
+from json import dumps
from time import perf_counter
from warnings import filterwarnings
@@ -278,6 +279,8 @@ def main(args):
print(f"broadcast | {broadcast}")
print(f"protocol | {args.protocol}")
print(f"device(s) | {args.devs}")
+ if args.device_memory_limit:
+ print(f"memory-limit | {format_bytes(args.device_memory_limit)}")
print(f"rmm-pool | {(not args.disable_rmm_pool)}")
print(f"frac-match | {args.frac_match}")
if args.protocol == "ucx":
@@ -304,18 +307,59 @@ def main(args):
if args.backend == "dask":
if args.markdown:
print("\nWorker-Worker Transfer Rates
\n\n```")
- print("(w1,w2) | 25% 50% 75% (total nbytes)")
+ print("(w1,w2) | 25% 50% 75% (total nbytes)")
print("-------------------------------")
for (d1, d2), bw in sorted(bandwidths.items()):
fmt = (
- "(%s,%s) | %s %s %s (%s)"
+ "(%s,%s) | %s %s %s (%s)"
if args.multi_node or args.sched_addr
- else "(%02d,%02d) | %s %s %s (%s)"
+ else "(%02d,%02d) | %s %s %s (%s)"
)
print(fmt % (d1, d2, bw[0], bw[1], bw[2], total_nbytes[(d1, d2)]))
if args.markdown:
print("```\n \n")
+ if args.benchmark_json:
+ bandwidths_json = {
+ f"bandwidth_({d1},{d2})_{i}"
+ if args.multi_node or args.sched_addr
+ else "(%02d,%02d)_%s" % (d1, d2, i): parse_bytes(v.rstrip("/s"))
+ for (d1, d2), bw in sorted(bandwidths.items())
+ for i, v in zip(
+ ["25%", "50%", "75%", "total_nbytes"],
+ [bw[0], bw[1], bw[2], total_nbytes[(d1, d2)]],
+ )
+ }
+
+ with open(args.benchmark_json, "a") as fp:
+ for data_processed, took in took_list:
+ fp.write(
+ dumps(
+ dict(
+ {
+ "backend": args.backend,
+ "merge_type": args.type,
+ "rows_per_chunk": args.chunk_size,
+ "base_chunks": args.base_chunks,
+ "other_chunks": args.other_chunks,
+ "broadcast": broadcast,
+ "protocol": args.protocol,
+ "devs": args.devs,
+ "device_memory_limit": args.device_memory_limit,
+ "rmm_pool": not args.disable_rmm_pool,
+ "tcp": args.enable_tcp_over_ucx,
+ "ib": args.enable_infiniband,
+ "nvlink": args.enable_nvlink,
+ "data_processed": data_processed,
+ "wall_clock": took,
+ "throughput": data_processed / took,
+ },
+ **bandwidths_json,
+ )
+ )
+ + "\n"
+ )
+
if args.multi_node:
client.shutdown()
client.close()
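Aside: the report written by the new --benchmark-json option above is line-delimited JSON, one object appended per run, so it can be read back without any dask-cuda dependency. A minimal sketch, assuming a hypothetical report path passed on the command line:

# Hedged sketch: read back the line-delimited JSON benchmark report.
# The file name "merge-report.jsonl" is only an example.
import json

with open("merge-report.jsonl") as fp:
    runs = [json.loads(line) for line in fp if line.strip()]

for run in runs:
    # Each record carries one run's wall clock time and derived throughput.
    print(run["protocol"], run["wall_clock"], run["throughput"])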
diff --git a/dask_cuda/benchmarks/local_cudf_shuffle.py b/dask_cuda/benchmarks/local_cudf_shuffle.py
index f329aa92b..f2c812d08 100644
--- a/dask_cuda/benchmarks/local_cudf_shuffle.py
+++ b/dask_cuda/benchmarks/local_cudf_shuffle.py
@@ -1,5 +1,6 @@
import contextlib
from collections import defaultdict
+from json import dumps
from time import perf_counter as clock
from warnings import filterwarnings
@@ -151,6 +152,8 @@ def main(args):
print(f"in-parts | {args.in_parts}")
print(f"protocol | {args.protocol}")
print(f"device(s) | {args.devs}")
+ if args.device_memory_limit:
+ print(f"memory-limit | {format_bytes(args.device_memory_limit)}")
print(f"rmm-pool | {(not args.disable_rmm_pool)}")
if args.protocol == "ucx":
print(f"tcp | {args.enable_tcp_over_ucx}")
@@ -176,18 +179,56 @@ def main(args):
if args.backend == "dask":
if args.markdown:
print("\nWorker-Worker Transfer Rates
\n\n```")
- print("(w1,w2) | 25% 50% 75% (total nbytes)")
+ print("(w1,w2) | 25% 50% 75% (total nbytes)")
print("-------------------------------")
for (d1, d2), bw in sorted(bandwidths.items()):
fmt = (
- "(%s,%s) | %s %s %s (%s)"
+ "(%s,%s) | %s %s %s (%s)"
if args.multi_node or args.sched_addr
- else "(%02d,%02d) | %s %s %s (%s)"
+ else "(%02d,%02d) | %s %s %s (%s)"
)
print(fmt % (d1, d2, bw[0], bw[1], bw[2], total_nbytes[(d1, d2)]))
if args.markdown:
print("```\n \n")
+ if args.benchmark_json:
+ bandwidths_json = {
+ f"bandwidth_({d1},{d2})_{i}"
+ if args.multi_node or args.sched_addr
+ else "(%02d,%02d)_%s" % (d1, d2, i): parse_bytes(v.rstrip("/s"))
+ for (d1, d2), bw in sorted(bandwidths.items())
+ for i, v in zip(
+ ["25%", "50%", "75%", "total_nbytes"],
+ [bw[0], bw[1], bw[2], total_nbytes[(d1, d2)]],
+ )
+ }
+
+ with open(args.benchmark_json, "a") as fp:
+ for data_processed, took in took_list:
+ fp.write(
+ dumps(
+ dict(
+ {
+ "backend": args.backend,
+ "partition_size": args.partition_size,
+ "in_parts": args.in_parts,
+ "protocol": args.protocol,
+ "devs": args.devs,
+ "device_memory_limit": args.device_memory_limit,
+ "rmm_pool": not args.disable_rmm_pool,
+ "tcp": args.enable_tcp_over_ucx,
+ "ib": args.enable_infiniband,
+ "nvlink": args.enable_nvlink,
+ "data_processed": data_processed,
+ "wall_clock": took,
+ "throughput": data_processed / took,
+ },
+ **bandwidths_json,
+ )
+ )
+ + "\n"
+ )
+
if args.multi_node:
client.shutdown()
client.close()
diff --git a/dask_cuda/benchmarks/local_cupy.py b/dask_cuda/benchmarks/local_cupy.py
index 9a07b2afe..a4bbc341a 100644
--- a/dask_cuda/benchmarks/local_cupy.py
+++ b/dask_cuda/benchmarks/local_cupy.py
@@ -1,6 +1,6 @@
import asyncio
from collections import defaultdict
-from json import dump
+from json import dumps
from time import perf_counter as clock
from warnings import filterwarnings
@@ -246,6 +246,8 @@ async def run(args):
print(f"Ignore-size | {format_bytes(args.ignore_size)}")
print(f"Protocol | {args.protocol}")
print(f"Device(s) | {args.devs}")
+ if args.device_memory_limit:
+ print(f"Memory limit | {format_bytes(args.device_memory_limit)}")
print(f"Worker Thread(s) | {args.threads_per_worker}")
print("==========================")
print("Wall-clock | npartitions")
@@ -266,37 +268,46 @@ async def run(args):
print(fmt % (d1, d2, bw[0], bw[1], bw[2], total_nbytes[(d1, d2)]))
if args.benchmark_json:
-
- d = {
- "operation": args.operation,
- "size": args.size,
- "second_size": args.second_size,
- "chunk_size": args.chunk_size,
- "compute_size": size,
- "compute_chunk_size": chunksize,
- "ignore_size": format_bytes(args.ignore_size),
- "protocol": args.protocol,
- "devs": args.devs,
- "threads_per_worker": args.threads_per_worker,
- "times": [
- {"wall_clock": took, "npartitions": npartitions}
- for (took, npartitions) in took_list
- ],
- "bandwidths": {
- f"({d1},{d2})"
- if args.multi_node or args.sched_addr
- else "(%02d,%02d)"
- % (d1, d2): {
- "25%": bw[0],
- "50%": bw[1],
- "75%": bw[2],
- "total_nbytes": total_nbytes[(d1, d2)],
- }
- for (d1, d2), bw in sorted(bandwidths.items())
- },
+ bandwidths_json = {
+ f"bandwidth_({d1},{d2})_{i}"
+ if args.multi_node or args.sched_addr
+ else "(%02d,%02d)_%s" % (d1, d2, i): parse_bytes(v.rstrip("/s"))
+ for (d1, d2), bw in sorted(bandwidths.items())
+ for i, v in zip(
+ ["25%", "50%", "75%", "total_nbytes"],
+ [bw[0], bw[1], bw[2], total_nbytes[(d1, d2)]],
+ )
}
- with open(args.benchmark_json, "w") as fp:
- dump(d, fp, indent=2)
+
+ with open(args.benchmark_json, "a") as fp:
+ for took, npartitions in took_list:
+ fp.write(
+ dumps(
+ dict(
+ {
+ "operation": args.operation,
+ "user_size": args.size,
+ "user_second_size": args.second_size,
+ "user_chunk_size": args.chunk_size,
+ "compute_size": size,
+ "compute_chunk_size": chunksize,
+ "ignore_size": args.ignore_size,
+ "protocol": args.protocol,
+ "devs": args.devs,
+ "device_memory_limit": args.device_memory_limit,
+ "worker_threads": args.threads_per_worker,
+ "rmm_pool": not args.disable_rmm_pool,
+ "tcp": args.enable_tcp_over_ucx,
+ "ib": args.enable_infiniband,
+ "nvlink": args.enable_nvlink,
+ "wall_clock": took,
+ "npartitions": npartitions,
+ },
+ **bandwidths_json,
+ )
+ )
+ + "\n"
+ )
# An SSHCluster will not automatically shut down, we have to
# ensure it does.
@@ -353,12 +364,6 @@ def parse_args():
"type": int,
"help": "Number of runs (default 3).",
},
- {
- "name": "--benchmark-json",
- "default": None,
- "type": str,
- "help": "Dump a JSON report of benchmarks (optional).",
- },
]
return parse_benchmark_args(
diff --git a/dask_cuda/benchmarks/local_cupy_map_overlap.py b/dask_cuda/benchmarks/local_cupy_map_overlap.py
index 374049ff7..077b212fb 100644
--- a/dask_cuda/benchmarks/local_cupy_map_overlap.py
+++ b/dask_cuda/benchmarks/local_cupy_map_overlap.py
@@ -1,5 +1,6 @@
import asyncio
from collections import defaultdict
+from json import dumps
from time import perf_counter as clock
from warnings import filterwarnings
@@ -125,29 +126,69 @@ async def run(args):
print("Roundtrip benchmark")
print("--------------------------")
- print(f"Size | {args.size}*{args.size}")
- print(f"Chunk-size | {args.chunk_size}")
- print(f"Ignore-size | {format_bytes(args.ignore_size)}")
- print(f"Protocol | {args.protocol}")
- print(f"Device(s) | {args.devs}")
+ print(f"Size | {args.size}*{args.size}")
+ print(f"Chunk-size | {args.chunk_size}")
+ print(f"Ignore-size | {format_bytes(args.ignore_size)}")
+ print(f"Protocol | {args.protocol}")
+ print(f"Device(s) | {args.devs}")
+ if args.device_memory_limit:
+ print(f"memory-limit | {format_bytes(args.device_memory_limit)}")
print("==========================")
- print("Wall-clock | npartitions")
+ print("Wall-clock | npartitions")
print("--------------------------")
for (took, npartitions) in took_list:
t = format_time(took)
- t += " " * (11 - len(t))
+ t += " " * (12 - len(t))
print(f"{t} | {npartitions}")
print("==========================")
- print("(w1,w2) | 25% 50% 75% (total nbytes)")
+ print("(w1,w2) | 25% 50% 75% (total nbytes)")
print("--------------------------")
for (d1, d2), bw in sorted(bandwidths.items()):
fmt = (
- "(%s,%s) | %s %s %s (%s)"
+ "(%s,%s) | %s %s %s (%s)"
if args.multi_node or args.sched_addr
- else "(%02d,%02d) | %s %s %s (%s)"
+ else "(%02d,%02d) | %s %s %s (%s)"
)
print(fmt % (d1, d2, bw[0], bw[1], bw[2], total_nbytes[(d1, d2)]))
+ if args.benchmark_json:
+ bandwidths_json = {
+ f"bandwidth_({d1},{d2})_{i}"
+ if args.multi_node or args.sched_addr
+ else "(%02d,%02d)_%s" % (d1, d2, i): parse_bytes(v.rstrip("/s"))
+ for (d1, d2), bw in sorted(bandwidths.items())
+ for i, v in zip(
+ ["25%", "50%", "75%", "total_nbytes"],
+ [bw[0], bw[1], bw[2], total_nbytes[(d1, d2)]],
+ )
+ }
+
+ with open(args.benchmark_json, "a") as fp:
+ for took, npartitions in took_list:
+ fp.write(
+ dumps(
+ dict(
+ {
+ "size": args.size * args.size,
+ "chunk_size": args.chunk_size,
+ "ignore_size": args.ignore_size,
+ "protocol": args.protocol,
+ "devs": args.devs,
+ "device_memory_limit": args.device_memory_limit,
+ "worker_threads": args.threads_per_worker,
+ "rmm_pool": not args.disable_rmm_pool,
+ "tcp": args.enable_tcp_over_ucx,
+ "ib": args.enable_infiniband,
+ "nvlink": args.enable_nvlink,
+ "wall_clock": took,
+ "npartitions": npartitions,
+ },
+ **bandwidths_json,
+ )
+ )
+ + "\n"
+ )
+
# An SSHCluster will not automatically shut down, we have to
# ensure it does.
if args.multi_node:
diff --git a/dask_cuda/benchmarks/utils.py b/dask_cuda/benchmarks/utils.py
index 4ee44820e..4cbe574c4 100644
--- a/dask_cuda/benchmarks/utils.py
+++ b/dask_cuda/benchmarks/utils.py
@@ -34,6 +34,16 @@ def parse_benchmark_args(description="Generic dask-cuda Benchmark", args_list=[]
type=str,
help="Write dask profile report (E.g. dask-report.html)",
)
+ parser.add_argument(
+ "--device-memory-limit",
+ default=None,
+ type=parse_bytes,
+ help="Size of the CUDA device LRU cache, which is used to determine when the "
+ "worker starts spilling to host memory. Can be an integer (bytes), float "
+ "(fraction of total device memory), string (like ``'5GB'`` or ``'5000M'``), or "
+ "``'auto'``, 0, or ``None`` to disable spilling to host (i.e. allow full "
+ "device memory usage).",
+ )
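Since the option above uses type=parse_bytes, human-readable size strings are converted to byte counts before they reach the workers. A small sketch of the accepted spellings; the numeric results reflect how dask.utils.parse_bytes interprets these strings and should be treated as illustrative:

# Hedged sketch of size strings accepted by --device-memory-limit.
from dask.utils import parse_bytes

assert parse_bytes("5GB") == 5_000_000_000    # decimal gigabytes
assert parse_bytes("5000M") == 5_000_000_000  # "M" is treated as megabytes
assert parse_bytes("1 GiB") == 2**30          # binary units are also accepted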
parser.add_argument(
"--rmm-pool-size",
default=None,
@@ -156,6 +166,13 @@ def parse_benchmark_args(description="Generic dask-cuda Benchmark", args_list=[]
type=str,
help="Generate plot output written to defined directory",
)
+ parser.add_argument(
+ "--benchmark-json",
+ default=None,
+ type=str,
+ help="Dump a line-delimited JSON report of benchmarks to this file (optional). "
+ "Creates file if it does not exist, appends otherwise.",
+ )
for args in args_list:
name = args.pop("name")
@@ -203,6 +220,8 @@ def get_cluster_options(args):
if args.enable_rdmacm:
worker_options["enable_rdmacm"] = ""
+ if args.device_memory_limit:
+ worker_options["device_memory_limit"] = args.device_memory_limit
if args.ucx_net_devices:
worker_options["ucx_net_devices"] = args.ucx_net_devices
@@ -229,6 +248,7 @@ def get_cluster_options(args):
"enable_nvlink": args.enable_nvlink,
"enable_rdmacm": args.enable_rdmacm,
"interface": args.interface,
+ "device_memory_limit": args.device_memory_limit,
}
if args.no_silence_logs:
cluster_kwargs["silence_logs"] = False
diff --git a/dask_cuda/cli/dask_cuda_worker.py b/dask_cuda/cli/dask_cuda_worker.py
index 8c48d4716..35bb703e7 100755
--- a/dask_cuda/cli/dask_cuda_worker.py
+++ b/dask_cuda/cli/dask_cuda_worker.py
@@ -142,6 +142,17 @@
``dask.temporary-directory`` in the local Dask configuration, using the current
working directory if this is not set.""",
)
+@click.option(
+ "--shared-filesystem/--no-shared-filesystem",
+ default=None,
+ type=bool,
+ help="""If `--shared-filesystem` is specified, inform JIT-Unspill that
+ `local_directory` is a shared filesystem available to all workers, whereas
+ `--no-shared-filesystem` informs it that it may not assume a shared filesystem.
+ If neither is specified, JIT-Unspill will decide based on the Dask config value
+ specified by `"jit-unspill-shared-fs"`.
+ Notice, a shared filesystem must support the `os.link()` operation.""",
+)
@click.option(
"--scheduler-file",
type=str,
@@ -274,6 +285,7 @@ def main(
dashboard,
dashboard_address,
local_directory,
+ shared_filesystem,
scheduler_file,
interface,
preload,
@@ -323,6 +335,7 @@ def main(
dashboard,
dashboard_address,
local_directory,
+ shared_filesystem,
scheduler_file,
interface,
preload,
diff --git a/dask_cuda/cuda_worker.py b/dask_cuda/cuda_worker.py
index 05f0b5154..0b6d1d6be 100644
--- a/dask_cuda/cuda_worker.py
+++ b/dask_cuda/cuda_worker.py
@@ -66,6 +66,7 @@ def __init__(
dashboard=True,
dashboard_address=":0",
local_directory=None,
+ shared_filesystem=None,
scheduler_file=None,
interface=None,
preload=[],
@@ -199,6 +200,9 @@ def del_pid_file():
"device_memory_limit": parse_device_memory_limit(
device_memory_limit, device_index=i
),
+ "memory_limit": memory_limit,
+ "local_directory": local_directory,
+ "shared_filesystem": shared_filesystem,
},
)
else:
@@ -241,7 +245,7 @@ def del_pid_file():
name=name if nprocs == 1 or not name else str(name) + "-" + str(i),
local_directory=local_directory,
config={
- "ucx": get_ucx_config(
+ "distributed.comm.ucx": get_ucx_config(
enable_tcp_over_ucx=enable_tcp_over_ucx,
enable_infiniband=enable_infiniband,
enable_nvlink=enable_nvlink,
diff --git a/dask_cuda/device_host_file.py b/dask_cuda/device_host_file.py
index 5e2463be0..c03fa2973 100644
--- a/dask_cuda/device_host_file.py
+++ b/dask_cuda/device_host_file.py
@@ -175,14 +175,12 @@ def __init__(
local_directory=None,
log_spilling=False,
):
- if local_directory is None:
- local_directory = dask.config.get("temporary-directory") or os.getcwd()
-
- if local_directory and not os.path.exists(local_directory):
- os.makedirs(local_directory, exist_ok=True)
- local_directory = os.path.join(local_directory, "dask-worker-space")
-
- self.disk_func_path = os.path.join(local_directory, "storage")
+ self.disk_func_path = os.path.join(
+ local_directory or dask.config.get("temporary-directory") or os.getcwd(),
+ "dask-worker-space",
+ "storage",
+ )
+ os.makedirs(self.disk_func_path, exist_ok=True)
self.host_func = dict()
self.disk_func = Func(
diff --git a/dask_cuda/explicit_comms/comms.py b/dask_cuda/explicit_comms/comms.py
index 1de033e32..dd001a3d6 100644
--- a/dask_cuda/explicit_comms/comms.py
+++ b/dask_cuda/explicit_comms/comms.py
@@ -34,9 +34,7 @@ def get_multi_lock_or_null_context(multi_lock_context, *args, **kwargs):
return MultiLock(*args, **kwargs)
else:
- # Use a null context that doesn't do anything
- # TODO: use `contextlib.nullcontext()` from Python 3.7+
- return contextlib.suppress()
+ return contextlib.nullcontext()
def default_comms(client: Optional[Client] = None) -> "CommsContext":
diff --git a/dask_cuda/explicit_comms/dataframe/shuffle.py b/dask_cuda/explicit_comms/dataframe/shuffle.py
index aeea71467..cce5480e7 100644
--- a/dask_cuda/explicit_comms/dataframe/shuffle.py
+++ b/dask_cuda/explicit_comms/dataframe/shuffle.py
@@ -18,7 +18,7 @@
from distributed import wait
from distributed.protocol import nested_deserialize, to_serialize
-from ...proxify_host_file import ProxifyHostFile
+from ...proxify_host_file import ProxyManager
from .. import comms
@@ -148,19 +148,17 @@ async def local_shuffle(
eps = s["eps"]
try:
- hostfile = first(iter(in_parts[0].values()))._obj_pxy.get(
- "hostfile", lambda: None
- )()
+ manager = first(iter(in_parts[0].values()))._obj_pxy.get("manager", None)
except AttributeError:
- hostfile = None
+ manager = None
- if isinstance(hostfile, ProxifyHostFile):
+ if isinstance(manager, ProxyManager):
def concat(args, ignore_index=False):
if len(args) < 2:
return args[0]
- return hostfile.add_external(dd_concat(args, ignore_index=ignore_index))
+ return manager.proxify(dd_concat(args, ignore_index=ignore_index))
else:
concat = dd_concat
diff --git a/dask_cuda/get_device_memory_objects.py b/dask_cuda/get_device_memory_objects.py
index deba96a06..385f70793 100644
--- a/dask_cuda/get_device_memory_objects.py
+++ b/dask_cuda/get_device_memory_objects.py
@@ -28,10 +28,7 @@ def get_device_memory_objects(obj) -> set:
@dispatch.register(object)
def get_device_memory_objects_default(obj):
if hasattr(obj, "_obj_pxy"):
- if obj._obj_pxy["serializers"] is None:
- return dispatch(obj._obj_pxy["obj"])
- else:
- return []
+ return dispatch(obj._obj_pxy["obj"])
if hasattr(obj, "data"):
return dispatch(obj.data)
if hasattr(obj, "_owner") and obj._owner is not None:
diff --git a/dask_cuda/initialize.py b/dask_cuda/initialize.py
index 416a7d6e1..1cb58c757 100644
--- a/dask_cuda/initialize.py
+++ b/dask_cuda/initialize.py
@@ -1,15 +1,72 @@
import logging
+import os
+import warnings
import click
import numba.cuda
import dask
+import distributed.comm.ucx
+from distributed.diagnostics.nvml import has_cuda_context
from .utils import get_ucx_config
logger = logging.getLogger(__name__)
+def _create_cuda_context_handler():
+ if int(os.environ.get("DASK_CUDA_TEST_SINGLE_GPU", "0")) != 0:
+ try:
+ numba.cuda.current_context()
+ except numba.cuda.cudadrv.error.CudaSupportError:
+ pass
+ else:
+ numba.cuda.current_context()
+
+
+def _create_cuda_context():
+ try:
+ # Added here to ensure the parent `LocalCUDACluster` process creates the CUDA
+ # context directly from the UCX module, thus avoiding a similar warning there.
+ try:
+ distributed.comm.ucx.init_once()
+ except ModuleNotFoundError:
+ # UCX initialization has to be delegated to Distributed, which will take care
+ # of setting the correct environment variables and importing `ucp` after that.
+ # Therefore, if ``import ucp`` fails we can just continue here.
+ pass
+
+ cuda_visible_device = int(
+ os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
+ )
+ ctx = has_cuda_context()
+ if ctx is not False and distributed.comm.ucx.cuda_context_created is False:
+ warnings.warn(
+ f"A CUDA context for device {ctx} already exists on process ID "
+ f"{os.getpid()}. This is often the result of a CUDA-enabled library "
+ "calling a CUDA runtime function before Dask-CUDA can spawn worker "
+ "processes. Please make sure any such function calls don't happen at "
+ "import time or in the global scope of a program."
+ )
+
+ _create_cuda_context_handler()
+
+ if distributed.comm.ucx.cuda_context_created is False:
+ ctx = has_cuda_context()
+ if ctx is not False and ctx != cuda_visible_device:
+ warnings.warn(
+ f"Worker with process ID {os.getpid()} should have a CUDA context "
+ f"assigned to device {cuda_visible_device}, but instead the CUDA "
+ f"context is on device {ctx}. This is often the result of a "
+ "CUDA-enabled library calling a CUDA runtime function before "
+ "Dask-CUDA can spawn worker processes. Please make sure any such "
+ "function calls don't happen at import time or in the global scope "
+ "of a program."
+ )
+ except Exception:
+ logger.error("Unable to start CUDA Context", exc_info=True)
+
+
def initialize(
create_cuda_context=True,
enable_tcp_over_ucx=False,
@@ -77,13 +134,6 @@ def initialize(
it is callable. Can be an integer or ``None`` if ``net_devices`` is not
callable.
"""
-
- if create_cuda_context:
- try:
- numba.cuda.current_context()
- except Exception:
- logger.error("Unable to start CUDA Context", exc_info=True)
-
ucx_config = get_ucx_config(
enable_tcp_over_ucx=enable_tcp_over_ucx,
enable_infiniband=enable_infiniband,
@@ -92,7 +142,10 @@ def initialize(
net_devices=net_devices,
cuda_device_index=cuda_device_index,
)
- dask.config.update(dask.config.global_config, {"ucx": ucx_config}, priority="new")
+ dask.config.set({"distributed.comm.ucx": ucx_config})
+
+ if create_cuda_context:
+ _create_cuda_context()
@click.command()
@@ -138,7 +191,4 @@ def dask_setup(
net_devices,
):
if create_cuda_context:
- try:
- numba.cuda.current_context()
- except Exception:
- logger.error("Unable to start CUDA Context", exc_info=True)
+ _create_cuda_context()
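Note that the UCX options now land under Distributed's own configuration namespace rather than a top-level "ucx" key. A rough sketch of what that means when reading the configuration back; the "nvlink" key mirrors get_ucx_config() output and is meant as an illustration only:

# Hedged sketch: setting a nested namespace the way initialize() now does ...
import dask

dask.config.set({"distributed.comm.ucx": {"nvlink": True}})

# ... makes the value visible where Distributed's UCX comm layer looks it up.
assert dask.config.get("distributed.comm.ucx.nvlink") is True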
diff --git a/dask_cuda/is_device_object.py b/dask_cuda/is_device_object.py
index 0654b4b4d..ab5844b79 100644
--- a/dask_cuda/is_device_object.py
+++ b/dask_cuda/is_device_object.py
@@ -35,6 +35,6 @@ def is_device_object_cudf_dataframe(df):
def is_device_object_cudf_series(s):
return True
- @is_device_object.register(cudf.Index)
+ @is_device_object.register(cudf.BaseIndex)
def is_device_object_cudf_index(s):
return True
diff --git a/dask_cuda/local_cuda_cluster.py b/dask_cuda/local_cuda_cluster.py
index 26831f60d..9ee4bbb6c 100644
--- a/dask_cuda/local_cuda_cluster.py
+++ b/dask_cuda/local_cuda_cluster.py
@@ -77,6 +77,12 @@ class LocalCUDACluster(LocalCluster):
``"path/to/files"``) or ``None`` to fall back on the value of
``dask.temporary-directory`` in the local Dask configuration, using the current
working directory if this is not set.
+ shared_filesystem: bool or None, default None
+ Whether the `local_directory` above is shared between all workers or not.
+ If ``None``, the "jit-unspill-shared-fs" config value is used, which
+ defaults to True. Notice, in all other cases this option defaults to False,
+ but on a local cluster it defaults to True -- we assume all workers use the
+ same filesystem.
protocol : str or None, default None
Protocol to use for communication. Can be a string (like ``"tcp"`` or
``"ucx"``), or ``None`` to automatically choose the correct protocol.
@@ -180,6 +186,7 @@ def __init__(
device_memory_limit=0.8,
data=None,
local_directory=None,
+ shared_filesystem=None,
protocol=None,
enable_tcp_over_ucx=False,
enable_infiniband=False,
@@ -213,7 +220,7 @@ def __init__(
n_workers = len(CUDA_VISIBLE_DEVICES)
if n_workers < 1:
raise ValueError("Number of workers cannot be less than 1.")
- self.host_memory_limit = parse_memory_limit(
+ self.memory_limit = parse_memory_limit(
memory_limit, threads_per_worker, n_workers
)
self.device_memory_limit = parse_device_memory_limit(
@@ -260,22 +267,29 @@ def __init__(
else:
self.jit_unspill = jit_unspill
+ if shared_filesystem is None:
+ # Notice, we assume a shared filesystem
+ shared_filesystem = dask.config.get("jit-unspill-shared-fs", default=True)
+
data = kwargs.pop("data", None)
if data is None:
if self.jit_unspill:
data = (
ProxifyHostFile,
- {"device_memory_limit": self.device_memory_limit,},
+ {
+ "device_memory_limit": self.device_memory_limit,
+ "memory_limit": self.memory_limit,
+ "local_directory": local_directory,
+ "shared_filesystem": shared_filesystem,
+ },
)
else:
data = (
DeviceHostFile,
{
"device_memory_limit": self.device_memory_limit,
- "memory_limit": self.host_memory_limit,
- "local_directory": local_directory
- or dask.config.get("temporary-directory")
- or os.getcwd(),
+ "memory_limit": self.memory_limit,
+ "local_directory": local_directory,
"log_spilling": log_spilling,
},
)
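A hedged usage sketch of the options threaded through above, combining JIT-Unspill with the new shared_filesystem flag; the sizes and the temporary directory below are placeholders:

# Minimal sketch, assuming a single-node setup with local GPUs.
from distributed import Client

from dask_cuda import LocalCUDACluster

cluster = LocalCUDACluster(
    jit_unspill=True,              # use ProxifyHostFile instead of DeviceHostFile
    device_memory_limit="10GB",    # spill device->host beyond this, per worker
    memory_limit="32GB",           # spill host->disk beyond this, per worker
    local_directory="/tmp/dask",   # where spilled data is written
    shared_filesystem=False,       # /tmp/dask is not shared between workers
)
client = Client(cluster)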
@@ -309,7 +323,7 @@ def __init__(
elif ucx_net_devices == "":
raise ValueError("ucx_net_devices can not be an empty string")
self.ucx_net_devices = ucx_net_devices
- self.set_ucx_net_devices = enable_infiniband
+ self.set_ucx_net_devices = enable_infiniband and ucx_net_devices is not None
self.host = kwargs.get("host", None)
initialize(
@@ -332,14 +346,14 @@ def __init__(
super().__init__(
n_workers=0,
threads_per_worker=threads_per_worker,
- memory_limit=self.host_memory_limit,
+ memory_limit=self.memory_limit,
processes=True,
data=data,
local_directory=local_directory,
protocol=protocol,
worker_class=worker_class,
config={
- "ucx": get_ucx_config(
+ "distributed.comm.ucx": get_ucx_config(
enable_tcp_over_ucx=enable_tcp_over_ucx,
enable_nvlink=enable_nvlink,
enable_infiniband=enable_infiniband,
@@ -394,7 +408,9 @@ def new_worker_spec(self):
net_dev = get_ucx_net_devices(cuda_device_index, self.ucx_net_devices)
if net_dev is not None:
spec["options"]["env"]["UCX_NET_DEVICES"] = net_dev
- spec["options"]["config"]["ucx"]["net-devices"] = net_dev
+ spec["options"]["config"]["distributed.comm.ucx"][
+ "net-devices"
+ ] = net_dev
spec["options"]["interface"] = get_ucx_net_devices(
cuda_device_index,
diff --git a/dask_cuda/proxify_device_objects.py b/dask_cuda/proxify_device_objects.py
index 92a92c95e..1ec7480a4 100644
--- a/dask_cuda/proxify_device_objects.py
+++ b/dask_cuda/proxify_device_objects.py
@@ -165,14 +165,9 @@ def wrapper(*args, **kwargs):
def proxify(obj, proxied_id_to_proxy, found_proxies, subclass=None):
_id = id(obj)
- if _id in proxied_id_to_proxy:
- ret = proxied_id_to_proxy[_id]
- finalize = ret._obj_pxy.get("external_finalize", None)
- if finalize:
- finalize()
- proxied_id_to_proxy[_id] = ret = asproxy(obj, subclass=subclass)
- else:
- proxied_id_to_proxy[_id] = ret = asproxy(obj, subclass=subclass)
+ if _id not in proxied_id_to_proxy:
+ proxied_id_to_proxy[_id] = asproxy(obj, subclass=subclass)
+ ret = proxied_id_to_proxy[_id]
found_proxies.append(ret)
return ret
@@ -190,11 +185,6 @@ def proxify_device_object_default(
def proxify_device_object_proxy_object(
obj, proxied_id_to_proxy, found_proxies, excl_proxies
):
- # We deserialize CUDA-serialized objects since it is very cheap and
- # makes it easy to administrate device memory usage
- if obj._obj_pxy_is_serialized() and "cuda" in obj._obj_pxy["serializers"]:
- obj._obj_pxy_deserialize()
-
# Check if `obj` is already known
if not obj._obj_pxy_is_serialized():
_id = id(obj._obj_pxy["obj"])
@@ -203,14 +193,6 @@ def proxify_device_object_proxy_object(
else:
proxied_id_to_proxy[_id] = obj
- finalize = obj._obj_pxy.get("external_finalize", None)
- if finalize:
- finalize()
- obj = obj._obj_pxy_copy()
- if not obj._obj_pxy_is_serialized():
- _id = id(obj._obj_pxy["obj"])
- proxied_id_to_proxy[_id] = obj
-
if not excl_proxies:
found_proxies.append(obj)
return obj
@@ -257,10 +239,19 @@ class FrameProxyObject(ProxyObject, cudf._lib.table.Table):
@dispatch.register(cudf.DataFrame)
@dispatch.register(cudf.Series)
- @dispatch.register(cudf.Index)
+ @dispatch.register(cudf.BaseIndex)
def proxify_device_object_cudf_dataframe(
obj, proxied_id_to_proxy, found_proxies, excl_proxies
):
return proxify(
obj, proxied_id_to_proxy, found_proxies, subclass=FrameProxyObject
)
+
+ try:
+ from dask.array.dispatch import percentile_lookup
+
+ from dask_cudf.backends import percentile_cudf
+
+ percentile_lookup.register(FrameProxyObject, percentile_cudf)
+ except ImportError:
+ pass
diff --git a/dask_cuda/proxify_host_file.py b/dask_cuda/proxify_host_file.py
index 951740cd7..2ebf4fc46 100644
--- a/dask_cuda/proxify_host_file.py
+++ b/dask_cuda/proxify_host_file.py
@@ -1,124 +1,353 @@
+import abc
+import logging
+import os
import threading
import time
+import uuid
+import warnings
import weakref
from collections import defaultdict
from typing import (
+ Any,
DefaultDict,
Dict,
Hashable,
Iterator,
List,
MutableMapping,
+ Optional,
Set,
Tuple,
)
+from weakref import ReferenceType
import dask
from dask.sizeof import sizeof
+from distributed.protocol.compression import decompress, maybe_compress
+from distributed.protocol.serialize import (
+ merge_and_deserialize,
+ register_serialization_family,
+ serialize_and_split,
+)
+from distributed.protocol.utils import pack_frames, unpack_frames
from .proxify_device_objects import proxify_device_objects, unproxify_device_objects
from .proxy_object import ProxyObject
-class UnspilledProxies:
- """Class to track current unspilled proxies"""
+class Proxies(abc.ABC):
+ """Abstract base class to implement tracking of proxies
+
+ This class is not thread-safe
+ """
def __init__(self):
- self.dev_mem_usage = 0
- self.proxy_id_to_dev_mems: DefaultDict[int, Set[Hashable]] = defaultdict(set)
+ self._proxy_id_to_proxy: Dict[int, ReferenceType[ProxyObject]] = {}
+ self._mem_usage = 0
+
+ def __len__(self) -> int:
+ return len(self._proxy_id_to_proxy)
+
+ @abc.abstractmethod
+ def mem_usage_add(self, proxy: ProxyObject) -> None:
+ """Given a new proxy, update `self._mem_usage`"""
+
+ @abc.abstractmethod
+ def mem_usage_remove(self, proxy: ProxyObject) -> None:
+ """Removal of proxy, update `self._mem_usage`"""
+
+ def add(self, proxy: ProxyObject) -> None:
+ """Add a proxy for tracking, calls `self.mem_usage_add`"""
+ assert not self.contains_proxy_id(id(proxy))
+ self._proxy_id_to_proxy[id(proxy)] = weakref.ref(proxy)
+ self.mem_usage_add(proxy)
+
+ def remove(self, proxy: ProxyObject) -> None:
+ """Remove proxy from tracking, calls `self.mem_usage_remove`"""
+ del self._proxy_id_to_proxy[id(proxy)]
+ self.mem_usage_remove(proxy)
+ if len(self._proxy_id_to_proxy) == 0:
+ if self._mem_usage != 0:
+ warnings.warn(
+ "ProxyManager is empty but the tally of "
+ f"{self} is {self._mem_usage} bytes. "
+ "Resetting the tally."
+ )
+ self._mem_usage = 0
+
+ def __iter__(self) -> Iterator[ProxyObject]:
+ for p in self._proxy_id_to_proxy.values():
+ ret = p()
+ if ret is not None:
+ yield ret
+
+ def contains_proxy_id(self, proxy_id: int) -> bool:
+ return proxy_id in self._proxy_id_to_proxy
+
+ def mem_usage(self) -> int:
+ return self._mem_usage
+
+
+class ProxiesOnHost(Proxies):
+ """Implement tracking of proxies on the CPU
+
+ This uses dask.sizeof to update memory usage.
+ """
+
+ def mem_usage_add(self, proxy: ProxyObject):
+ self._mem_usage += sizeof(proxy)
+
+ def mem_usage_remove(self, proxy: ProxyObject):
+ self._mem_usage -= sizeof(proxy)
+
+
+class ProxiesOnDisk(ProxiesOnHost):
+ """Implement tracking of proxies on the Disk"""
+
+
+class ProxiesOnDevice(Proxies):
+ """Implement tracking of proxies on the GPU
+
+ This is a bit more complicated than ProxiesOnHost because we have to
+ handle the fact that multiple proxy objects can refer to the same underlying
+ device memory object. Thus, we have to track aliasing and make sure
+ we don't count down the memory usage prematurely.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.proxy_id_to_dev_mems: Dict[int, Set[Hashable]] = {}
self.dev_mem_to_proxy_ids: DefaultDict[Hashable, Set[int]] = defaultdict(set)
- def add(self, proxy: ProxyObject):
+ def mem_usage_add(self, proxy: ProxyObject):
proxy_id = id(proxy)
- if proxy_id not in self.proxy_id_to_dev_mems:
- for dev_mem in proxy._obj_pxy_get_device_memory_objects():
- self.proxy_id_to_dev_mems[proxy_id].add(dev_mem)
- ps = self.dev_mem_to_proxy_ids[dev_mem]
- if len(ps) == 0:
- self.dev_mem_usage += sizeof(dev_mem)
- ps.add(proxy_id)
-
- def remove(self, proxy: ProxyObject):
+ assert proxy_id not in self.proxy_id_to_dev_mems
+ self.proxy_id_to_dev_mems[proxy_id] = set()
+ for dev_mem in proxy._obj_pxy_get_device_memory_objects():
+ self.proxy_id_to_dev_mems[proxy_id].add(dev_mem)
+ ps = self.dev_mem_to_proxy_ids[dev_mem]
+ if len(ps) == 0:
+ self._mem_usage += sizeof(dev_mem)
+ ps.add(proxy_id)
+
+ def mem_usage_remove(self, proxy: ProxyObject):
proxy_id = id(proxy)
- if proxy_id in self.proxy_id_to_dev_mems:
- for dev_mem in self.proxy_id_to_dev_mems.pop(proxy_id):
- self.dev_mem_to_proxy_ids[dev_mem].remove(proxy_id)
- if len(self.dev_mem_to_proxy_ids[dev_mem]) == 0:
- del self.dev_mem_to_proxy_ids[dev_mem]
- self.dev_mem_usage -= sizeof(dev_mem)
-
- def __iter__(self):
- return iter(self.proxy_id_to_dev_mems)
+ for dev_mem in self.proxy_id_to_dev_mems.pop(proxy_id):
+ self.dev_mem_to_proxy_ids[dev_mem].remove(proxy_id)
+ if len(self.dev_mem_to_proxy_ids[dev_mem]) == 0:
+ del self.dev_mem_to_proxy_ids[dev_mem]
+ self._mem_usage -= sizeof(dev_mem)
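To make the aliasing bookkeeping above concrete, here is a toy sketch (plain ints and strings stand in for real proxies and device buffers): a buffer's size is counted once no matter how many proxies reference it, and only released when the last referencing proxy goes away.

# Toy sketch of the aliasing-aware tally used by ProxiesOnDevice.
from collections import defaultdict

buffer_to_proxies = defaultdict(set)
device_mem_usage = 0

def on_add(proxy_id, buffer_id, buffer_size):
    """Count a buffer's size only when its first referencing proxy arrives."""
    global device_mem_usage
    if not buffer_to_proxies[buffer_id]:
        device_mem_usage += buffer_size
    buffer_to_proxies[buffer_id].add(proxy_id)

def on_remove(proxy_id, buffer_id, buffer_size):
    """Release the tally only when the last referencing proxy goes away."""
    global device_mem_usage
    buffer_to_proxies[buffer_id].discard(proxy_id)
    if not buffer_to_proxies[buffer_id]:
        device_mem_usage -= buffer_size

on_add(1, "buf", 100)
on_add(2, "buf", 100)          # an alias of the same device buffer
assert device_mem_usage == 100  # counted once, not twice
on_remove(1, "buf", 100)
on_remove(2, "buf", 100)
assert device_mem_usage == 0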
-class ProxiesTally:
+class ProxyManager:
"""
- This class together with UnspilledProxies implements the tracking of current
- objects in device memory and the total memory usage. It turns out having to
- re-calculate device memory usage continuously is too expensive.
-
- We have to track four events:
- - When adding a new key to the host file
- - When removing a key from the host file
- - When a proxy in the host file is deserialized
- - When a proxy in the host file is serialized
-
- However, it gets a bit complicated because:
- - The value of a key in the host file can contain many proxy objects and a single
- proxy object can be referred from many keys
- - Multiple proxy objects can refer to the same underlying device memory object
- - Proxy objects are not hashable thus we have to use the `id()` as key in
- dictionaries
-
- ProxiesTally and UnspilledProxies implements this by carefully maintaining
- dictionaries that maps to/from keys, proxy objects, and device memory objects.
+ This class together with Proxies, ProxiesOnHost, and ProxiesOnDevice
+ implements the tracking of all known proxies and their total host/device
+ memory usage. It turns out having to re-calculate memory usage continuously
+ is too expensive.
+
+ The idea is to have the ProxifyHostFile, or the proxies themselves, notify the
+ manager when they change location (device, host, or disk). The manager then
+ tallies the total memory usage.
+
+ Notice, the manager only keeps weak references to the proxies.
"""
- def __init__(self):
+ def __init__(self, device_memory_limit: int, memory_limit: int):
self.lock = threading.RLock()
- self.proxy_id_to_proxy: Dict[int, ProxyObject] = {}
- self.key_to_proxy_ids: DefaultDict[Hashable, Set[int]] = defaultdict(set)
- self.proxy_id_to_keys: DefaultDict[int, Set[Hashable]] = defaultdict(set)
- self.unspilled_proxies = UnspilledProxies()
+ self._disk = ProxiesOnDisk()
+ self._host = ProxiesOnHost()
+ self._dev = ProxiesOnDevice()
+ self._device_memory_limit = device_memory_limit
+ self._host_memory_limit = memory_limit
- def add_key(self, key, proxies: List[ProxyObject]):
+ def __repr__(self) -> str:
with self.lock:
- for proxy in proxies:
- proxy_id = id(proxy)
- self.proxy_id_to_proxy[proxy_id] = proxy
- self.key_to_proxy_ids[key].add(proxy_id)
- self.proxy_id_to_keys[proxy_id].add(key)
- if not proxy._obj_pxy_is_serialized():
- self.unspilled_proxies.add(proxy)
-
- def del_key(self, key):
+ return (
+ f""
+ )
+
+ def __len__(self) -> int:
+ return len(self._disk) + len(self._host) + len(self._dev)
+
+ def pprint(self) -> str:
with self.lock:
- for proxy_id in self.key_to_proxy_ids.pop(key, ()):
- self.proxy_id_to_keys[proxy_id].remove(key)
- if len(self.proxy_id_to_keys[proxy_id]) == 0:
- del self.proxy_id_to_keys[proxy_id]
- self.unspilled_proxies.remove(self.proxy_id_to_proxy.pop(proxy_id))
+ ret = f"{self}:"
+ if len(self) == 0:
+ return ret + " Empty"
+ ret += "\n"
+ for proxy in self._disk:
+ ret += f" disk - {repr(proxy)}\n"
+ for proxy in self._host:
+ ret += f" host - {repr(proxy)}\n"
+ for proxy in self._dev:
+ ret += f" dev - {repr(proxy)}\n"
+ return ret[:-1] # Strip last newline
+
+ def get_proxies_by_serializer(self, serializer: Optional[str]) -> Proxies:
+ if serializer == "disk":
+ return self._disk
+ elif serializer in ("dask", "pickle"):
+ return self._host
+ else:
+ return self._dev
- def spill_proxy(self, proxy: ProxyObject):
+ def contains(self, proxy_id: int) -> bool:
with self.lock:
- self.unspilled_proxies.remove(proxy)
+ return (
+ self._disk.contains_proxy_id(proxy_id)
+ or self._host.contains_proxy_id(proxy_id)
+ or self._dev.contains_proxy_id(proxy_id)
+ )
- def unspill_proxy(self, proxy: ProxyObject):
+ def add(self, proxy: ProxyObject) -> None:
with self.lock:
- self.unspilled_proxies.add(proxy)
+ if not self.contains(id(proxy)):
+ self.get_proxies_by_serializer(proxy._obj_pxy["serializer"]).add(proxy)
- def get_unspilled_proxies(self) -> Iterator[ProxyObject]:
+ def remove(self, proxy: ProxyObject) -> None:
with self.lock:
- for proxy_id in self.unspilled_proxies:
- ret = self.proxy_id_to_proxy[proxy_id]
- assert not ret._obj_pxy_is_serialized()
- yield ret
+ # Find where the proxy is located and remove it
+ proxies: Optional[Proxies] = None
+ if self._disk.contains_proxy_id(id(proxy)):
+ proxies = self._disk
+ if self._host.contains_proxy_id(id(proxy)):
+ proxies = self._host
+ if self._dev.contains_proxy_id(id(proxy)):
+ assert proxies is None, "Proxy in multiple locations"
+ proxies = self._dev
+ assert proxies is not None, "Trying to remove unknown proxy"
+ proxies.remove(proxy)
+
+ def move(
+ self,
+ proxy: ProxyObject,
+ from_serializer: Optional[str],
+ to_serializer: Optional[str],
+ ) -> None:
+ with self.lock:
+ src = self.get_proxies_by_serializer(from_serializer)
+ dst = self.get_proxies_by_serializer(to_serializer)
+ if src is not dst:
+ src.remove(proxy)
+ dst.add(proxy)
+
+ def validate(self):
+ with self.lock:
+ for serializer in ("disk", "dask", "cuda"):
+ proxies = self.get_proxies_by_serializer(serializer)
+ for p in proxies:
+ assert (
+ self.get_proxies_by_serializer(p._obj_pxy["serializer"])
+ is proxies
+ )
+ for i, p in proxies._proxy_id_to_proxy.items():
+ assert p() is not None
+ assert i == id(p())
+ for p in proxies:
+ if p._obj_pxy_is_serialized():
+ header, _ = p._obj_pxy["obj"]
+ assert header["serializer"] == p._obj_pxy["serializer"]
+
+ def proxify(self, obj: object) -> object:
+ with self.lock:
+ found_proxies: List[ProxyObject] = []
+ proxied_id_to_proxy: Dict[int, ProxyObject] = {}
+ ret = proxify_device_objects(obj, proxied_id_to_proxy, found_proxies)
+ last_access = time.monotonic()
+ for p in found_proxies:
+ p._obj_pxy["last_access"] = last_access
+ if not self.contains(id(p)):
+ p._obj_pxy_register_manager(self)
+ self.add(p)
+ self.maybe_evict()
+ return ret
- def get_proxied_id_to_proxy(self) -> Dict[int, ProxyObject]:
- return {id(p._obj_pxy["obj"]): p for p in self.get_unspilled_proxies()}
+ def get_dev_buffer_to_proxies(self) -> DefaultDict[Hashable, List[ProxyObject]]:
+ with self.lock:
+ # Notice, multiple proxy objects can point to different non-overlapping
+ # parts of the same device buffer.
+ ret = defaultdict(list)
+ for proxy in self._dev:
+ for dev_buffer in proxy._obj_pxy_get_device_memory_objects():
+ ret[dev_buffer].append(proxy)
+ return ret
- def get_dev_mem_usage(self) -> int:
- return self.unspilled_proxies.dev_mem_usage
+ def get_dev_access_info(
+ self,
+ ) -> Tuple[int, List[Tuple[int, int, List[ProxyObject]]]]:
+ with self.lock:
+ total_dev_mem_usage = 0
+ dev_buf_access = []
+ for dev_buf, proxies in self.get_dev_buffer_to_proxies().items():
+ last_access = max(p._obj_pxy.get("last_access", 0) for p in proxies)
+ size = sizeof(dev_buf)
+ dev_buf_access.append((last_access, size, proxies))
+ total_dev_mem_usage += size
+ assert total_dev_mem_usage == self._dev.mem_usage()
+ return total_dev_mem_usage, dev_buf_access
+
+ def get_host_access_info(self) -> Tuple[int, List[Tuple[int, int, ProxyObject]]]:
+ with self.lock:
+ total_mem_usage = 0
+ access_info = []
+ for p in self._host:
+ size = sizeof(p)
+ access_info.append((p._obj_pxy.get("last_access", 0), size, p))
+ total_mem_usage += size
+ return total_mem_usage, access_info
+
+ def maybe_evict_from_device(self, extra_dev_mem=0) -> None:
+ if ( # Shortcut when not evicting
+ self._dev.mem_usage() + extra_dev_mem <= self._device_memory_limit
+ ):
+ return
+
+ with self.lock:
+ total_dev_mem_usage, dev_buf_access = self.get_dev_access_info()
+ total_dev_mem_usage += extra_dev_mem
+ if total_dev_mem_usage > self._device_memory_limit:
+ dev_buf_access.sort(key=lambda x: (x[0], -x[1]))
+ for _, size, proxies in dev_buf_access:
+ for p in proxies:
+ # Serialize to disk, which "dask" and "pickle" does
+ p._obj_pxy_serialize(serializers=("dask", "pickle"))
+ total_dev_mem_usage -= size
+ if total_dev_mem_usage <= self._device_memory_limit:
+ break
+
+ def maybe_evict_from_host(self, extra_host_mem=0) -> None:
+ if ( # Shortcut when not evicting
+ self._host.mem_usage() + extra_host_mem <= self._host_memory_limit
+ ):
+ return
+
+ with self.lock:
+ total_host_mem_usage, info = self.get_host_access_info()
+ total_host_mem_usage += extra_host_mem
+ if total_host_mem_usage > self._host_memory_limit:
+ info.sort(key=lambda x: (x[0], -x[1]))
+ for _, size, proxy in info:
+ ProxifyHostFile.serialize_proxy_to_disk_inplace(proxy)
+ total_host_mem_usage -= size
+ if total_host_mem_usage <= self._host_memory_limit:
+ break
+
+ def force_evict_from_host(self) -> int:
+ with self.lock:
+ _, info = self.get_host_access_info()
+ info.sort(key=lambda x: (x[0], -x[1]))
+ for _, size, proxy in info:
+ ProxifyHostFile.serialize_proxy_to_disk_inplace(proxy)
+ return size
+ return 0
+
+ def maybe_evict(self, extra_dev_mem=0) -> None:
+ self.maybe_evict_from_device(extra_dev_mem)
+ self.maybe_evict_from_host()
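For reference, the eviction order used by maybe_evict_from_device() and maybe_evict_from_host() above sorts candidates by (last_access, -size): least recently used first, and among equally old entries the largest one goes first. A self-contained illustration of that sort key:

# Sketch of the eviction ordering; names and sizes are made up.
candidates = [
    # (last_access, size, name)
    (10.0, 200, "old-small"),
    (10.0, 800, "old-large"),
    (99.0, 500, "recent"),
]
candidates.sort(key=lambda x: (x[0], -x[1]))
assert [name for _, _, name in candidates] == ["old-large", "old-small", "recent"]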
class ProxifyHostFile(MutableMapping):
@@ -143,20 +372,49 @@ class ProxifyHostFile(MutableMapping):
----------
device_memory_limit: int
Number of bytes of CUDA device memory used before spilling to host.
- compatibility_mode: bool or None
+ memory_limit: int
+ Number of bytes of host memory used before spilling to disk.
+ local_directory: str or None, default None
+ Path on local machine to store temporary files. Can be a string (like
+ ``"path/to/files"``) or ``None`` to fall back on the value of
+ ``dask.temporary-directory`` in the local Dask configuration, using the
+ current working directory if this is not set.
+ WARNING: this **cannot** change while running, thus all serialization to
+ disk uses the same directory.
+ shared_filesystem: bool or None, default None
+ Whether the `local_directory` above is shared between all workers or not.
+ If ``None``, the "jit-unspill-shared-fs" config value is used, which
+ defaults to False.
+ Notice, a shared filesystem must support the `os.link()` operation.
+ compatibility_mode: bool or None, default None
Enables compatibility-mode, which means that items are un-proxified before
retrieval. This makes it possible to get some of the JIT-unspill benefits
without having to be ProxyObject compatible. In order to still allow specific
ProxyObjects, set the `mark_as_explicit_proxies=True` when proxifying with
- `proxify_device_objects()`. If None, the "jit-unspill-compatibility-mode"
+ `proxify_device_objects()`. If ``None``, the "jit-unspill-compatibility-mode"
config value are used, which defaults to False.
"""
- def __init__(self, device_memory_limit: int, compatibility_mode: bool = None):
- self.device_memory_limit = device_memory_limit
- self.store = {}
- self.lock = threading.RLock()
- self.proxies_tally = ProxiesTally()
+ # Notice, we define the following as static variables because they are used by
+ # the static register_disk_spilling() method.
+ _spill_directory: Optional[str] = None
+ _spill_shared_filesystem: bool
+ _spill_to_disk_prefix: str = f"spilled-data-{uuid.uuid4()}"
+ _spill_to_disk_counter: int = 0
+ lock = threading.RLock()
+
+ def __init__(
+ self,
+ *,
+ device_memory_limit: int,
+ memory_limit: int,
+ local_directory: str = None,
+ shared_filesystem: bool = None,
+ compatibility_mode: bool = None,
+ ):
+ self.store: Dict[Hashable, Any] = {}
+ self.manager = ProxyManager(device_memory_limit, memory_limit)
+ self.register_disk_spilling(local_directory, shared_filesystem)
if compatibility_mode is None:
self.compatibility_mode = dask.config.get(
"jit-unspill-compatibility-mode", default=False
@@ -164,6 +422,12 @@ def __init__(self, device_memory_limit: int, compatibility_mode: bool = None):
else:
self.compatibility_mode = compatibility_mode
+ # It is a bit hacky to forcefully capture the "distributed.worker" logger;
+ # eventually it would be better to have a dedicated logger. For now this
+ # is ok, since it allows users to read the logs with client.get_worker_logs().
+ # A proper solution would require changes to Distributed.
+ self.logger = logging.getLogger("distributed.worker")
+
def __contains__(self, key):
return key in self.store
@@ -174,122 +438,163 @@ def __iter__(self):
with self.lock:
return iter(self.store)
- def get_dev_buffer_to_proxies(self) -> DefaultDict[Hashable, List[ProxyObject]]:
- with self.lock:
- # Notice, multiple proxy object can point to different non-overlapping
- # parts of the same device buffer.
- ret = defaultdict(list)
- for proxy in self.proxies_tally.get_unspilled_proxies():
- for dev_buffer in proxy._obj_pxy_get_device_memory_objects():
- ret[dev_buffer].append(proxy)
- return ret
+ @property
+ def fast(self):
+ """Dask uses this to trigger CPU-to-disk spilling"""
+ if len(self.manager._host) == 0:
+ return False # We have nothing in host memory to spill
- def get_access_info(self) -> Tuple[int, List[Tuple[int, int, List[ProxyObject]]]]:
- with self.lock:
- total_dev_mem_usage = 0
- dev_buf_access = []
- for dev_buf, proxies in self.get_dev_buffer_to_proxies().items():
- last_access = max(p._obj_pxy.get("last_access", 0) for p in proxies)
- size = sizeof(dev_buf)
- dev_buf_access.append((last_access, size, proxies))
- total_dev_mem_usage += size
- return total_dev_mem_usage, dev_buf_access
-
- def add_external(self, obj):
- """Add an external object to the hostfile that count against the
- device_memory_limit but isn't part of the store.
-
- Normally, we use __setitem__ to store objects in the hostfile and make it
- count against the device_memory_limit with the inherent consequence that
- the objects are not freeable before subsequential calls to __delitem__.
- This is a problem for long running tasks that want objects to count against
- the device_memory_limit while freeing them ASAP without explicit calls to
- __delitem__.
-
- Developer Notes
- ---------------
- In order to avoid holding references to the found proxies in `obj`, we
- wrap them in `weakref.proxy(p)` and adds them to the `proxies_tally`.
- In order to remove them from the `proxies_tally` again, we attach a
- finalize(p) on the wrapped proxies that calls del_external().
- """
-
- # Notice, since `self.store` isn't modified, no lock is needed
- found_proxies: List[ProxyObject] = []
- proxied_id_to_proxy = {}
- # Notice, we are excluding found objects that are already proxies
- ret = proxify_device_objects(
- obj, proxied_id_to_proxy, found_proxies, excl_proxies=True
- )
- last_access = time.monotonic()
- self_weakref = weakref.ref(self)
- for p in found_proxies:
- name = id(p)
- finalize = weakref.finalize(p, self.del_external, name)
- external = weakref.proxy(p)
- p._obj_pxy["hostfile"] = self_weakref
- p._obj_pxy["last_access"] = last_access
- p._obj_pxy["external"] = external
- p._obj_pxy["external_finalize"] = finalize
- self.proxies_tally.add_key(name, [external])
- self.maybe_evict()
- return ret
+ class EvictDummy:
+ @staticmethod
+ def evict():
+ return None, None, self.manager.force_evict_from_host()
- def del_external(self, name):
- self.proxies_tally.del_key(name)
+ return EvictDummy()
def __setitem__(self, key, value):
with self.lock:
if key in self.store:
# Make sure we register the removal of an existing key
del self[key]
-
- found_proxies: List[ProxyObject] = []
- proxied_id_to_proxy = self.proxies_tally.get_proxied_id_to_proxy()
- self.store[key] = proxify_device_objects(
- value, proxied_id_to_proxy, found_proxies
- )
- last_access = time.monotonic()
- self_weakref = weakref.ref(self)
- for p in found_proxies:
- p._obj_pxy["hostfile"] = self_weakref
- p._obj_pxy["last_access"] = last_access
- assert "external" not in p._obj_pxy
-
- self.proxies_tally.add_key(key, found_proxies)
- self.maybe_evict()
+ self.store[key] = self.manager.proxify(value)
def __getitem__(self, key):
with self.lock:
ret = self.store[key]
if self.compatibility_mode:
ret = unproxify_device_objects(ret, skip_explicit_proxies=True)
- self.maybe_evict()
+ self.manager.maybe_evict()
return ret
def __delitem__(self, key):
with self.lock:
del self.store[key]
- self.proxies_tally.del_key(key)
- def evict(self, proxy: ProxyObject):
- proxy._obj_pxy_serialize(serializers=("dask", "pickle"))
+ @classmethod
+ def gen_file_path(cls) -> str:
+ """Generate a unique file path"""
+ with cls.lock:
+ cls._spill_to_disk_counter += 1
+ assert cls._spill_directory is not None
+ return os.path.join(
+ cls._spill_directory,
+ f"{cls._spill_to_disk_prefix}-{cls._spill_to_disk_counter}",
+ )
- def maybe_evict(self, extra_dev_mem=0):
- if ( # Shortcut when not evicting
- self.proxies_tally.get_dev_mem_usage() + extra_dev_mem
- <= self.device_memory_limit
- ):
- return
+ @classmethod
+ def register_disk_spilling(
+ cls, local_directory: str = None, shared_filesystem: bool = None
+ ):
+ """Register Dask serializers that write to disk
+
+ This is a class method because the registration of a Dask
+ serializer/deserializer pair is a global operation, thus we can
+ only register one such pair. This means that all instances of
+ the ``ProxifyHostFile`` end up using the same ``local_directory``.
+
+ Parameters
+ ----------
+ local_directory : str or None, default None
+ Path to the root directory to write serialized data.
+ Can be a string or None to fall back on the value of
+ ``dask.temporary-directory`` in the local Dask configuration,
+ using the current working directory if this is not set.
+ WARNING: this **cannot** change while running, thus all
+ serialization to disk uses the same directory.
+ shared_filesystem: bool or None, default None
+ Whether the `local_directory` above is shared between all workers or not.
+ If ``None``, the "jit-unspill-shared-fs" config value is used, which
+ defaults to False.
+ """
+ path = os.path.join(
+ local_directory or dask.config.get("temporary-directory") or os.getcwd(),
+ "dask-worker-space",
+ "jit-unspill-disk-storage",
+ )
+ if cls._spill_directory is None:
+ cls._spill_directory = path
+ elif cls._spill_directory != path:
+ raise ValueError("Cannot change the JIT-Unspilling disk path")
+ os.makedirs(cls._spill_directory, exist_ok=True)
+
+ if shared_filesystem is None:
+ cls._spill_shared_filesystem = dask.config.get(
+ "jit-unspill-shared-fs", default=False
+ )
+ else:
+ cls._spill_shared_filesystem = shared_filesystem
+
+ def disk_dumps(x):
+ header, frames = serialize_and_split(x, on_error="raise")
+ if frames:
+ compression, frames = zip(*map(maybe_compress, frames))
+ else:
+ compression = []
+ header["compression"] = compression
+ header["count"] = len(frames)
+
+ path = cls.gen_file_path()
+ with open(path, "wb") as f:
+ f.write(pack_frames(frames))
+ return (
+ {
+ "serializer": "disk",
+ "path": path,
+ "shared-filesystem": cls._spill_shared_filesystem,
+ "disk-sub-header": header,
+ },
+ [],
+ )
- with self.lock:
- total_dev_mem_usage, dev_buf_access = self.get_access_info()
- total_dev_mem_usage += extra_dev_mem
- if total_dev_mem_usage > self.device_memory_limit:
- dev_buf_access.sort(key=lambda x: (x[0], -x[1]))
- for _, size, proxies in dev_buf_access:
- for p in proxies:
- self.evict(p)
- total_dev_mem_usage -= size
- if total_dev_mem_usage <= self.device_memory_limit:
- break
+ def disk_loads(header, frames):
+ assert frames == []
+ with open(header["path"], "rb") as f:
+ frames = unpack_frames(f.read())
+ os.remove(header["path"])
+ if "compression" in header["disk-sub-header"]:
+ frames = decompress(header["disk-sub-header"], frames)
+ return merge_and_deserialize(header["disk-sub-header"], frames)
+
+ register_serialization_family("disk", disk_dumps, disk_loads)
+
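The disk round trip performed by disk_dumps()/disk_loads() above boils down to packing the frames produced by Dask serialization into a single blob, writing it to a file, and unpacking on load. A hedged sketch of that round trip using the same Distributed helpers, without the compression step and with a throwaway temporary path:

# Minimal sketch of the frame pack/unpack round trip; paths are illustrative.
import os
import tempfile

from distributed.protocol.serialize import merge_and_deserialize, serialize_and_split
from distributed.protocol.utils import pack_frames, unpack_frames

header, frames = serialize_and_split({"x": list(range(10))}, on_error="raise")

path = os.path.join(tempfile.mkdtemp(), "spilled-blob")
with open(path, "wb") as f:
    f.write(pack_frames(frames))

with open(path, "rb") as f:
    restored = merge_and_deserialize(header, unpack_frames(f.read()))
os.remove(path)

assert restored == {"x": list(range(10))}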
+ @classmethod
+ def serialize_proxy_to_disk_inplace(cls, proxy: ProxyObject):
+ """Serialize `proxy` to disk.
+
+ Avoid de-serializing if `proxy` is serialized using "dask" or
+ "pickle". In this case the already serialized data is written
+ directly to disk.
+
+ Parameters
+ ----------
+ proxy : ProxyObject
+ Proxy object to serialize using the "disk" serializer.
+ """
+ manager = proxy._obj_pxy_get_manager()
+ with manager.lock:
+ if not proxy._obj_pxy_is_serialized():
+ proxy._obj_pxy_serialize(serializers=("disk",))
+ else:
+ header, frames = proxy._obj_pxy["obj"]
+ if header["serializer"] in ("dask", "pickle"):
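+ # Already serialized by "dask" or "pickle": write the existing
+ # frames directly to disk without deserializing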
+ path = cls.gen_file_path()
+ with open(path, "wb") as f:
+ f.write(pack_frames(frames))
+ proxy._obj_pxy["obj"] = (
+ {
+ "serializer": "disk",
+ "path": path,
+ "shared-filesystem": cls._spill_shared_filesystem,
+ "disk-sub-header": header,
+ },
+ [],
+ )
+ proxy._obj_pxy["serializer"] = "disk"
+ manager.move(
+ proxy,
+ from_serializer=header["serializer"],
+ to_serializer="disk",
+ )
+ elif header["serializer"] != "disk":
+ proxy._obj_pxy_deserialize()
+ proxy._obj_pxy_serialize(serializers=("disk",))
diff --git a/dask_cuda/proxy_object.py b/dask_cuda/proxy_object.py
index 649f400ed..86eb255af 100644
--- a/dask_cuda/proxy_object.py
+++ b/dask_cuda/proxy_object.py
@@ -1,11 +1,14 @@
import copy
import functools
import operator
+import os
import pickle
import threading
import time
+import uuid
from collections import OrderedDict
-from typing import Any, Dict, List, Optional, Set
+from contextlib import nullcontext
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Type, Union
import pandas
@@ -13,9 +16,12 @@
import dask.array.core
import dask.dataframe.methods
import dask.dataframe.utils
+import dask.utils
import distributed.protocol
import distributed.utils
from dask.sizeof import sizeof
+from distributed.protocol.compression import decompress
+from distributed.protocol.utils import unpack_frames
from distributed.worker import dumps_function, loads_function
try:
@@ -31,21 +37,26 @@
from .get_device_memory_objects import get_device_memory_objects
from .is_device_object import is_device_object
+if TYPE_CHECKING:
+ from .proxify_host_file import ProxyManager
+
+
# List of attributes that should be copied to the proxy at creation, which makes
# them accessible without deserialization of the proxied object
_FIXED_ATTRS = ["name", "__len__"]
-def asproxy(obj, serializers=None, subclass=None) -> "ProxyObject":
+def asproxy(
+ obj: object, serializers: Iterable[str] = None, subclass: Type["ProxyObject"] = None
+) -> "ProxyObject":
"""Wrap `obj` in a ProxyObject object if it isn't already.
Parameters
----------
obj: object
Object to wrap in a ProxyObject object.
- serializers: list(str), optional
- List of serializers to use to serialize `obj`. If None,
- no serialization is done.
+ serializers: Iterable[str], optional
+ Serializers to use to serialize `obj`. If None, no serialization is done.
subclass: class, optional
Specify a subclass of ProxyObject to create instead of ProxyObject.
`subclass` must be pickable.
@@ -54,9 +65,10 @@ def asproxy(obj, serializers=None, subclass=None) -> "ProxyObject":
-------
The ProxyObject proxying `obj`
"""
-
- if hasattr(obj, "_obj_pxy"): # Already a proxy object
+ if isinstance(obj, ProxyObject): # Already a proxy object
ret = obj
+ elif isinstance(obj, (list, set, tuple, dict)):
+ raise ValueError(f"Cannot wrap a collection ({type(obj)}) in a proxy object")
else:
fixed_attr = {}
for attr in _FIXED_ATTRS:
@@ -81,7 +93,7 @@ def asproxy(obj, serializers=None, subclass=None) -> "ProxyObject":
typename=dask.utils.typename(type(obj)),
is_cuda_object=is_device_object(obj),
subclass=subclass_serialized,
- serializers=None,
+ serializer=None,
explicit_proxy=False,
)
if serializers is not None:
@@ -112,7 +124,7 @@ def unproxy(obj):
return obj
-def _obj_pxy_cache_wrapper(attr_name):
+def _obj_pxy_cache_wrapper(attr_name: str):
"""Caching the access of attr_name in ProxyObject._obj_pxy_cache"""
def wrapper1(func):
@@ -130,6 +142,31 @@ def wrapper2(self: "ProxyObject"):
return wrapper1
+class ProxyManagerDummy:
+ """Dummy of a ProxyManager that does nothing
+
+ This is a dummy class returned by `ProxyObject._obj_pxy_get_manager()`
+ when no manager has been registered for the proxy object. It implements
+ dummy methods that don't do anything; it exists purely for convenience.
+ """
+
+ def add(self, *args, **kwargs):
+ pass
+
+ def remove(self, *args, **kwargs):
+ pass
+
+ def move(self, *args, **kwargs):
+ pass
+
+ def maybe_evict(self, *args, **kwargs):
+ pass
+
+ @property
+ def lock(self):
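+ # A no-op context manager so that `with manager.lock:` works
+ # even when no real manager is registered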
+ return nullcontext()
+
+
class ProxyObject:
"""Object wrapper/proxy for serializable objects
@@ -183,9 +220,8 @@ class ProxyObject:
subclass: bytes
Pickled type to use instead of ProxyObject when deserializing. The type
must inherit from ProxyObject.
- serializers: list(str), optional
- List of serializers to use to serialize `obj`. If None, `obj`
- isn't serialized.
+ serializer: str, optional
+ Serializer used to serialize `obj`. If None, `obj` isn't serialized.
explicit_proxy: bool
Mark the proxy object as "explicit", which means that the user allows it
as input argument to dask tasks even in compatibility-mode.
@@ -198,8 +234,8 @@ def __init__(
type_serialized: bytes,
typename: str,
is_cuda_object: bool,
- subclass: bytes,
- serializers: Optional[List[str]],
+ subclass: Optional[bytes],
+ serializer: Optional[str],
explicit_proxy: bool,
):
self._obj_pxy = {
@@ -209,19 +245,20 @@ def __init__(
"typename": typename,
"is_cuda_object": is_cuda_object,
"subclass": subclass,
- "serializers": serializers,
+ "serializer": serializer,
"explicit_proxy": explicit_proxy,
}
self._obj_pxy_lock = threading.RLock()
- self._obj_pxy_cache = {}
+ self._obj_pxy_cache: Dict[str, Any] = {}
def __del__(self):
- """In order to call `external_finalize()` ASAP, we call it here"""
- external_finalize = self._obj_pxy.get("external_finalize", None)
- if external_finalize is not None:
- external_finalize()
+ """Unregister from the manager (if any) and remove any on-disk spill file"""
+ self._obj_pxy_get_manager().remove(self)
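+ # A proxy spilled to disk owns its spill file and must remove it
+ # when garbage collected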
+ if self._obj_pxy["serializer"] == "disk":
+ header, _ = self._obj_pxy["obj"]
+ os.remove(header["path"])
- def _obj_pxy_get_init_args(self, include_obj=True):
+ def _obj_pxy_get_init_args(self, include_obj=True) -> OrderedDict:
"""Return the attributes needed to initialize a ProxyObject
Notice, the returned dictionary is ordered as the __init__() arguments
@@ -242,7 +279,7 @@ def _obj_pxy_get_init_args(self, include_obj=True):
"typename",
"is_cuda_object",
"subclass",
- "serializers",
+ "serializer",
"explicit_proxy",
]
return OrderedDict([(a, self._obj_pxy[a]) for a in args])
@@ -260,17 +297,48 @@ def _obj_pxy_copy(self) -> "ProxyObject":
args["obj"] = self._obj_pxy["obj"]
return type(self)(**args)
- def _obj_pxy_is_serialized(self):
+ def _obj_pxy_register_manager(self, manager: "ProxyManager") -> None:
+ """Register a manager
+
+ The manager tallies the total memory usage of proxies and
+ evicts/serializes proxy objects as needed.
+
+ In order to prevent deadlocks, the proxy now uses the lock of the
+ manager.
+
+ Parameters
+ ----------
+ manager: ProxyManager
+ The manager to manage this proxy object
+ """
+ assert "manager" not in self._obj_pxy
+ self._obj_pxy["manager"] = manager
+ self._obj_pxy_lock = manager.lock
+
+ def _obj_pxy_get_manager(self) -> Union["ProxyManager", ProxyManagerDummy]:
+ """Get the registered manager or a dummy
+
+ Returns
+ -------
+ manager: ProxyManager or ProxyManagerDummy
+ The manager of this proxy object, or a dummy if none is registered
+ """
+ ret = self._obj_pxy.get("manager", None)
+ if ret is None:
+ ret = ProxyManagerDummy()
+ return ret
+
+ def _obj_pxy_is_serialized(self) -> bool:
"""Return whether the proxied object is serialized or not"""
- return self._obj_pxy["serializers"] is not None
+ return self._obj_pxy["serializer"] is not None
- def _obj_pxy_serialize(self, serializers):
+ def _obj_pxy_serialize(self, serializers: Iterable[str]):
"""Inplace serialization of the proxied object using the `serializers`
Parameters
----------
- serializers: tuple[str]
- Tuple of serializers to use to serialize the proxied object.
+ serializers: Iterable[str]
+ Serializers to use to serialize the proxied object.
Returns
-------
@@ -282,30 +350,29 @@ def _obj_pxy_serialize(self, serializers):
if not serializers:
raise ValueError("Please specify a list of serializers")
- if type(serializers) is not tuple:
- serializers = tuple(serializers)
-
with self._obj_pxy_lock:
- if self._obj_pxy["serializers"] is not None:
- if self._obj_pxy["serializers"] == serializers:
+ if self._obj_pxy_is_serialized():
+ if self._obj_pxy["serializer"] in serializers:
return self._obj_pxy["obj"] # Nothing to be done
else:
# The proxied object is serialized with other serializers
self._obj_pxy_deserialize()
- if self._obj_pxy["serializers"] is None:
- self._obj_pxy["obj"] = distributed.protocol.serialize(
+ manager = self._obj_pxy_get_manager()
+ with manager.lock:
+ header, _ = self._obj_pxy["obj"] = distributed.protocol.serialize(
self._obj_pxy["obj"], serializers, on_error="raise"
)
- self._obj_pxy["serializers"] = serializers
- hostfile = self._obj_pxy.get("hostfile", lambda: None)()
- if hostfile is not None:
- external = self._obj_pxy.get("external", self)
- hostfile.proxies_tally.spill_proxy(external)
-
- # Invalidate the (possible) cached "device_memory_objects"
- self._obj_pxy_cache.pop("device_memory_objects", None)
- return self._obj_pxy["obj"]
+ assert "is-collection" not in header # Collections not allowed
+ org_ser, new_ser = self._obj_pxy["serializer"], header["serializer"]
+ self._obj_pxy["serializer"] = new_ser
+
+ # Tell the manager (if any) that this proxy has changed serializer
+ manager.move(self, from_serializer=org_ser, to_serializer=new_ser)
+
+ # Invalidate the (possible) cached "device_memory_objects"
+ self._obj_pxy_cache.pop("device_memory_objects", None)
+ return self._obj_pxy["obj"]
def _obj_pxy_deserialize(self, maybe_evict: bool = True):
"""Inplace deserialization of the proxied object
@@ -313,7 +380,7 @@ def _obj_pxy_deserialize(self, maybe_evict: bool = True):
Parameters
----------
maybe_evict: bool
- Before deserializing, call associated hostfile.maybe_evict()
+ Before deserializing, maybe evict managed proxy objects
Returns
-------
@@ -321,27 +388,32 @@ def _obj_pxy_deserialize(self, maybe_evict: bool = True):
The proxied object (deserialized)
"""
with self._obj_pxy_lock:
- if self._obj_pxy["serializers"] is not None:
- hostfile = self._obj_pxy.get("hostfile", lambda: None)()
- # When not deserializing a CUDA-serialized proxied, we might have
- # to evict because of the increased device memory usage.
- if maybe_evict and "cuda" not in self._obj_pxy["serializers"]:
- if hostfile is not None:
- # In order to avoid a potential deadlock, we skip the
- # `maybe_evict()` call if another thread is also accessing
- # the hostfile.
- if hostfile.lock.acquire(blocking=False):
- try:
- hostfile.maybe_evict(self.__sizeof__())
- finally:
- hostfile.lock.release()
-
- header, frames = self._obj_pxy["obj"]
- self._obj_pxy["obj"] = distributed.protocol.deserialize(header, frames)
- self._obj_pxy["serializers"] = None
- if hostfile is not None:
- external = self._obj_pxy.get("external", self)
- hostfile.proxies_tally.unspill_proxy(external)
+ if self._obj_pxy_is_serialized():
+ manager = self._obj_pxy_get_manager()
+ with manager.lock:
+ # When not deserializing a CUDA-serialized proxy, tell the
+ # manager that it might have to evict because of the increased
+ # device memory usage.
+ if (
+ manager
+ and maybe_evict
+ and self._obj_pxy["serializer"] != "cuda"
+ ):
+ manager.maybe_evict(self.__sizeof__())
+
+ # Deserialize the proxied object
+ header, frames = self._obj_pxy["obj"]
+ self._obj_pxy["obj"] = distributed.protocol.deserialize(
+ header, frames
+ )
+
+ # Tell the manager (if any) that this proxy has changed serializer
+ manager.move(
+ self,
+ from_serializer=self._obj_pxy["serializer"],
+ to_serializer=None,
+ )
+ self._obj_pxy["serializer"] = None
self._obj_pxy["last_access"] = time.monotonic()
return self._obj_pxy["obj"]
@@ -354,16 +426,12 @@ def _obj_pxy_is_cuda_object(self) -> bool:
ret : boolean
Is the proxied object a CUDA object?
"""
- with self._obj_pxy_lock:
- return self._obj_pxy["is_cuda_object"]
+ return self._obj_pxy["is_cuda_object"]
@_obj_pxy_cache_wrapper("device_memory_objects")
- def _obj_pxy_get_device_memory_objects(self) -> Set:
+ def _obj_pxy_get_device_memory_objects(self) -> set:
"""Return all device memory objects within the proxied object.
- Calling this when the proxied object is serialized returns the
- empty list.
-
Returns
-------
ret : set
@@ -409,6 +477,21 @@ def __setattr__(self, name, val):
else:
object.__setattr__(self._obj_pxy_deserialize(), name, val)
+ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
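+ # Deserialize any proxy objects among the inputs and kwargs before
+ # dispatching the ufunc to the proxied object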
+ inputs = tuple(
+ o._obj_pxy_deserialize() if isinstance(o, ProxyObject) else o
+ for o in inputs
+ )
+ kwargs = {
+ key: value._obj_pxy_deserialize()
+ if isinstance(value, ProxyObject)
+ else value
+ for key, value in kwargs.items()
+ }
+ return self._obj_pxy_deserialize().__array_ufunc__(
+ ufunc, method, *inputs, **kwargs
+ )
+
def __str__(self):
return str(self._obj_pxy_deserialize())
@@ -416,13 +499,13 @@ def __repr__(self):
with self._obj_pxy_lock:
typename = self._obj_pxy["typename"]
ret = f"<{dask.utils.typename(type(self))} at {hex(id(self))} of {typename}"
- if self._obj_pxy["serializers"] is not None:
- ret += f" (serialized={repr(self._obj_pxy['serializers'])})>"
+ if self._obj_pxy_is_serialized():
+ ret += f" (serialized={repr(self._obj_pxy['serializer'])})>"
else:
ret += f" at {hex(id(self._obj_pxy['obj']))}>"
return ret
- @property
+ @property # type: ignore # mypy doesn't support decorated property
@_obj_pxy_cache_wrapper("type_serialized")
def __class__(self):
return pickle.loads(self._obj_pxy["type_serialized"])
@@ -515,8 +598,8 @@ def __mod__(self, other):
def __divmod__(self, other):
return divmod(self._obj_pxy_deserialize(), other)
- def __pow__(self, other, *args):
- return pow(self._obj_pxy_deserialize(), other, *args)
+ def __pow__(self, other):
+ return pow(self._obj_pxy_deserialize(), other)
def __lshift__(self, other):
return self._obj_pxy_deserialize() << other
@@ -668,27 +751,63 @@ def obj_pxy_is_device_object(obj: ProxyObject):
return obj._obj_pxy_is_cuda_object()
+def handle_disk_serialized(obj: ProxyObject):
+ """Handle serialization of an already disk serialized proxy
+
+ On a shared filesystem, we do not have to deserialize; instead we
+ make a hard link of the file.
+
+ On a non-shared filesystem, we deserialize the proxy to host memory.
+ """
+
+ header, frames = obj._obj_pxy["obj"]
+ if header["shared-filesystem"]:
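+ # Make a hard link of the spill file under a unique name and point a
+ # copy of the header at it, leaving the original file untouched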
+ old_path = header["path"]
+ new_path = f"{old_path}-linked-{uuid.uuid4()}"
+ os.link(old_path, new_path)
+ header = copy.copy(header)
+ header["path"] = new_path
+ else:
+ # When not on a shared filesystem, we deserialize to host memory
+ assert frames == []
+ with open(header["path"], "rb") as f:
+ frames = unpack_frames(f.read())
+ os.remove(header["path"])
+ if "compression" in header["disk-sub-header"]:
+ frames = decompress(header["disk-sub-header"], frames)
+ header = header["disk-sub-header"]
+ obj._obj_pxy["serializer"] = header["serializer"]
+ return header, frames
+
+
@distributed.protocol.dask_serialize.register(ProxyObject)
def obj_pxy_dask_serialize(obj: ProxyObject):
+ """The dask serialization of ProxyObject used by Dask when communicating using TCP
+
+ As serializers, it uses "dask" or "pickle", which means that proxied CUDA objects
+ are spilled to main memory before being communicated. Deserialization is needed
+ unless `obj` is serialized to disk on a shared filesystem; see `handle_disk_serialized()`.
"""
- The generic serialization of ProxyObject used by Dask when communicating
- ProxyObject. As serializers, it uses "dask" or "pickle", which means
- that proxied CUDA objects are spilled to main memory before communicated.
- """
- header, frames = obj._obj_pxy_serialize(serializers=("dask", "pickle"))
+ if obj._obj_pxy["serializer"] == "disk":
+ header, frames = handle_disk_serialized(obj)
+ else:
+ header, frames = obj._obj_pxy_serialize(serializers=("dask", "pickle"))
meta = obj._obj_pxy_get_init_args(include_obj=False)
return {"proxied-header": header, "obj-pxy-meta": meta}, frames
@distributed.protocol.cuda.cuda_serialize.register(ProxyObject)
def obj_pxy_cuda_serialize(obj: ProxyObject):
+ """The CUDA serialization of ProxyObject used by Dask when communicating using UCX
+
+ As serializers, it uses "cuda", which means that proxied CUDA objects are _not_
+ spilled to main memory before being communicated. However, we still have to handle
+ disk-serialized proxies as in `obj_pxy_dask_serialize()`.
"""
- The CUDA serialization of ProxyObject used by Dask when communicating using UCX
- or another CUDA friendly communication library. As serializers, it uses "cuda",
- which means that proxied CUDA objects are _not_ spilled to main memory.
- """
- if obj._obj_pxy["serializers"] is not None: # Already serialized
+ if obj._obj_pxy["serializer"] in ("dask", "pickle"):
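+ # Already serialized to host memory, send the existing frames as-is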
header, frames = obj._obj_pxy["obj"]
+ elif obj._obj_pxy["serializer"] == "disk":
+ header, frames = handle_disk_serialized(obj)
else:
# Notice, since obj._obj_pxy_serialize() is a inplace operation, we make a
# shallow copy of `obj` to avoid introducing a CUDA-serialized object in
diff --git a/dask_cuda/tests/test_dask_cuda_worker.py b/dask_cuda/tests/test_dask_cuda_worker.py
index 3e6478c89..c4b134b03 100644
--- a/dask_cuda/tests/test_dask_cuda_worker.py
+++ b/dask_cuda/tests/test_dask_cuda_worker.py
@@ -5,14 +5,14 @@
import pytest
-from distributed import Client
+from distributed import Client, wait
from distributed.system import MEMORY_LIMIT
from distributed.utils_test import loop # noqa: F401
from distributed.utils_test import popen
import rmm
-from dask_cuda.utils import get_n_gpus, wait_workers
+from dask_cuda.utils import get_gpu_count_mig, get_n_gpus, wait_workers
_driver_version = rmm._cuda.gpu.driverGetVersion()
_runtime_version = rmm._cuda.gpu.runtimeGetVersion()
@@ -186,3 +186,53 @@ def test_unknown_argument():
ret = subprocess.run(["dask-cuda-worker", "--my-argument"], capture_output=True)
assert ret.returncode != 0
assert b"Scheduler address: --my-argument" in ret.stderr
+
+
+def test_cuda_mig_visible_devices_and_memory_limit_and_nthreads(loop): # noqa: F811
+ init_nvmlstatus = os.environ.get("DASK_DISTRIBUTED__DIAGNOSTICS__NVML")
+ try:
+ os.environ["DASK_DISTRIBUTED__DIAGNOSTICS__NVML"] = "False"
+ uuids = get_gpu_count_mig(return_uuids=True)[1]
+ # Test only with some MIG instances, assuming the test bed
+ # does not have a huge number of MIG instances
+ if len(uuids) > 0:
+ uuids = [i.decode("utf-8") for i in uuids]
+ else:
+ pytest.skip("No MIG devices found")
+ CUDA_VISIBLE_DEVICES = ",".join(uuids)
+ os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES
+ nthreads = len(uuids)
+ with popen(["dask-scheduler", "--port", "9359", "--no-dashboard"]):
+ with popen(
+ [
+ "dask-cuda-worker",
+ "127.0.0.1:9359",
+ "--host",
+ "127.0.0.1",
+ "--nthreads",
+ str(nthreads),
+ "--no-dashboard",
+ "--worker-class",
+ "dask_cuda.utils.MockWorker",
+ ]
+ ):
+ with Client("127.0.0.1:9359", loop=loop) as client:
+ assert wait_workers(client, n_gpus=len(uuids))
+ # Check to see if all workers are up and
+ # CUDA_VISIBLE_DEVICES cycles properly
+
+ def get_visible_devices():
+ return os.environ["CUDA_VISIBLE_DEVICES"]
+
+ result = client.run(get_visible_devices)
+ wait(result)
+ assert all(len(v.split(",")) == len(uuids) for v in result.values())
+ for i in range(len(uuids)):
+ assert set(v.split(",")[i] for v in result.values()) == set(
+ uuids
+ )
+ finally:
+ if "CUDA_VISIBLE_DEVICES" in os.environ:
+ del os.environ["CUDA_VISIBLE_DEVICES"]
+ if init_nvmlstatus:
+ os.environ["DASK_DISTRIBUTED__DIAGNOSTICS__NVML"] = init_nvmlstatus
diff --git a/dask_cuda/tests/test_dgx.py b/dask_cuda/tests/test_dgx.py
index 7c029b7e7..164bf2e5f 100644
--- a/dask_cuda/tests/test_dgx.py
+++ b/dask_cuda/tests/test_dgx.py
@@ -249,6 +249,10 @@ def check_ucx_options():
def test_ucx_infiniband_nvlink(params):
ucp = pytest.importorskip("ucp") # NOQA: F841
+ if params["enable_infiniband"]:
+ if not any([at.startswith("rc") for at in ucp.get_active_transports()]):
+ pytest.skip("No support available for 'rc' transport in UCX")
+
p = mp.Process(
target=_test_ucx_infiniband_nvlink,
args=(
@@ -259,6 +263,12 @@ def test_ucx_infiniband_nvlink(params):
)
p.start()
p.join()
+
+ # Starting a new cluster on the same pytest process after an rdmacm cluster
+ # has been used may cause UCX-Py to complain about being already initialized.
+ if params["enable_rdmacm"] is True:
+ ucp.reset()
+
assert not p.exitcode
@@ -274,13 +284,13 @@ def _test_dask_cuda_worker_ucx_net_devices(enable_rdmacm):
# Enable proper variables for scheduler
sched_env = os.environ.copy()
- sched_env["DASK_UCX__INFINIBAND"] = "True"
- sched_env["DASK_UCX__TCP"] = "True"
- sched_env["DASK_UCX__CUDA_COPY"] = "True"
- sched_env["DASK_UCX__NET_DEVICES"] = openfabrics_devices[0]
+ sched_env["DASK_DISTRIBUTED__COMM__UCX__INFINIBAND"] = "True"
+ sched_env["DASK_DISTRIBUTED__COMM__UCX__TCP"] = "True"
+ sched_env["DASK_DISTRIBUTED__COMM__UCX__CUDA_COPY"] = "True"
+ sched_env["DASK_DISTRIBUTED__COMM__UCX__NET_DEVICES"] = openfabrics_devices[0]
if enable_rdmacm:
- sched_env["DASK_UCX__RDMACM"] = "True"
+ sched_env["DASK_DISTRIBUTED__COMM__UCX__RDMACM"] = "True"
sched_addr = get_ip_interface("ib0")
sched_url = "ucx://" + sched_addr + ":9379"
@@ -370,9 +380,17 @@ def test_dask_cuda_worker_ucx_net_devices(enable_rdmacm):
if _ucx_110:
pytest.skip("UCX 1.10 and higher should rely on default UCX_NET_DEVICES")
+ if not any([at.startswith("rc") for at in ucp.get_active_transports()]):
+ pytest.skip("No support available for 'rc' transport in UCX")
+
p = mp.Process(
target=_test_dask_cuda_worker_ucx_net_devices, args=(enable_rdmacm,),
)
p.start()
p.join()
+
+ # The processes may be killed in the test, preventing UCX-Py from cleaning
+ # up all objects. Reset to prevent issues in tests that run afterwards.
+ ucp.reset()
+
assert not p.exitcode
diff --git a/dask_cuda/tests/test_explicit_comms.py b/dask_cuda/tests/test_explicit_comms.py
index 06efe907c..281a930e5 100644
--- a/dask_cuda/tests/test_explicit_comms.py
+++ b/dask_cuda/tests/test_explicit_comms.py
@@ -8,6 +8,7 @@
import dask
from dask import dataframe as dd
from dask.dataframe.shuffle import partitioning_index
+from dask.dataframe.utils import assert_eq
from distributed import Client, get_worker
from distributed.deploy.local import LocalCluster
@@ -15,6 +16,7 @@
from dask_cuda.explicit_comms import comms
from dask_cuda.explicit_comms.dataframe.shuffle import shuffle as explicit_comms_shuffle
from dask_cuda.initialize import initialize
+from dask_cuda.utils import get_ucx_config
mp = mp.get_context("spawn")
ucp = pytest.importorskip("ucp")
@@ -30,7 +32,7 @@ async def my_rank(state, arg):
def _test_local_cluster(protocol):
dask.config.update(
dask.config.global_config,
- {"ucx": {"tcp": True, "cuda_copy": True,},},
+ {"distributed.comm.ucx": get_ucx_config(enable_tcp_over_ucx=True),},
priority="new",
)
@@ -96,15 +98,11 @@ def check_partitions(df, npartitions):
def _test_dataframe_shuffle(backend, protocol, n_workers):
if backend == "cudf":
cudf = pytest.importorskip("cudf")
- from cudf.testing._utils import assert_eq
-
initialize(enable_tcp_over_ucx=True)
else:
- from dask.dataframe.utils import assert_eq
-
dask.config.update(
dask.config.global_config,
- {"ucx": {"tcp": True, "cuda_copy": True,},},
+ {"distributed.comm.ucx": get_ucx_config(enable_tcp_over_ucx=True),},
priority="new",
)
@@ -143,10 +141,7 @@ def _test_dataframe_shuffle(backend, protocol, n_workers):
# Check the values of `ddf` (ignoring the row order)
expected = df.sort_values("key")
got = ddf.compute().sort_values("key")
- if backend == "cudf":
- assert_eq(got, expected)
- else:
- pd.testing.assert_frame_equal(got, expected)
+ assert_eq(got, expected)
@pytest.mark.parametrize("nworkers", [1, 2, 3])
@@ -201,15 +196,13 @@ def test_dask_use_explicit_comms():
def _test_dataframe_shuffle_merge(backend, protocol, n_workers):
if backend == "cudf":
cudf = pytest.importorskip("cudf")
- from cudf.testing._utils import assert_eq
initialize(enable_tcp_over_ucx=True)
else:
- from dask.dataframe.utils import assert_eq
dask.config.update(
dask.config.global_config,
- {"ucx": {"tcp": True, "cuda_copy": True,},},
+ {"distributed.comm.ucx": get_ucx_config(enable_tcp_over_ucx=True),},
priority="new",
)
@@ -242,10 +235,7 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers):
)
with dask.config.set(explicit_comms=True):
got = ddf1.merge(ddf2, on="key").set_index("key").compute()
- if backend == "cudf":
- assert_eq(got, expected)
- else:
- pd.testing.assert_frame_equal(got, expected)
+ assert_eq(got, expected)
@pytest.mark.parametrize("nworkers", [1, 2, 4])
@@ -264,7 +254,6 @@ def test_dataframe_shuffle_merge(backend, protocol, nworkers):
def _test_jit_unspill(protocol):
import cudf
- from cudf.testing._utils import assert_eq
with dask_cuda.LocalCUDACluster(
protocol=protocol,
diff --git a/dask_cuda/tests/test_initialize.py b/dask_cuda/tests/test_initialize.py
index cb99de1be..f26351e4c 100644
--- a/dask_cuda/tests/test_initialize.py
+++ b/dask_cuda/tests/test_initialize.py
@@ -29,7 +29,7 @@ def _test_initialize_ucx_tcp():
n_workers=1,
threads_per_worker=1,
processes=True,
- config={"ucx": get_ucx_config(**kwargs)},
+ config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
) as cluster:
with Client(cluster) as client:
res = da.from_array(numpy.arange(10000), chunks=(1000,))
@@ -68,7 +68,7 @@ def _test_initialize_ucx_nvlink():
n_workers=1,
threads_per_worker=1,
processes=True,
- config={"ucx": get_ucx_config(**kwargs)},
+ config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
) as cluster:
with Client(cluster) as client:
res = da.from_array(numpy.arange(10000), chunks=(1000,))
@@ -110,7 +110,7 @@ def _test_initialize_ucx_infiniband():
n_workers=1,
threads_per_worker=1,
processes=True,
- config={"ucx": get_ucx_config(**kwargs)},
+ config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
) as cluster:
with Client(cluster) as client:
res = da.from_array(numpy.arange(10000), chunks=(1000,))
diff --git a/dask_cuda/tests/test_local_cuda_cluster.py b/dask_cuda/tests/test_local_cuda_cluster.py
index 1d5af958b..464304f76 100644
--- a/dask_cuda/tests/test_local_cuda_cluster.py
+++ b/dask_cuda/tests/test_local_cuda_cluster.py
@@ -10,7 +10,7 @@
from dask_cuda import CUDAWorker, LocalCUDACluster, utils
from dask_cuda.initialize import initialize
-from dask_cuda.utils import MockWorker
+from dask_cuda.utils import MockWorker, get_gpu_count_mig
_driver_version = rmm._cuda.gpu.driverGetVersion()
_runtime_version = rmm._cuda.gpu.runtimeGetVersion()
@@ -206,3 +206,40 @@ async def test_cluster_worker():
await new_worker
await client.wait_for_workers(2)
await new_worker.close()
+
+
+@gen_test(timeout=20)
+async def test_available_mig_workers():
+ import dask
+
+ init_nvmlstatus = os.environ.get("DASK_DISTRIBUTED__DIAGNOSTICS__NVML")
+ try:
+ os.environ["DASK_DISTRIBUTED__DIAGNOSTICS__NVML"] = "False"
+ dask.config.refresh()
+ uuids = get_gpu_count_mig(return_uuids=True)[1]
+ if len(uuids) > 0:
+ uuids = [i.decode("utf-8") for i in uuids]
+ else:
+ pytest.skip("No MIG devices found")
+ CUDA_VISIBLE_DEVICES = ",".join(uuids)
+ os.environ["CUDA_VISIBLE_DEVICES"] = CUDA_VISIBLE_DEVICES
+ async with LocalCUDACluster(
+ CUDA_VISIBLE_DEVICES=CUDA_VISIBLE_DEVICES, asynchronous=True
+ ) as cluster:
+ async with Client(cluster, asynchronous=True) as client:
+ assert len(cluster.workers) == len(uuids)
+
+ # Check to see if CUDA_VISIBLE_DEVICES cycles properly
+ def get_visible_devices():
+ return os.environ["CUDA_VISIBLE_DEVICES"]
+
+ result = await client.run(get_visible_devices)
+
+ assert all(len(v.split(",")) == len(uuids) for v in result.values())
+ for i in range(len(uuids)):
+ assert set(v.split(",")[i] for v in result.values()) == set(uuids)
+ finally:
+ if "CUDA_VISIBLE_DEVICES" in os.environ:
+ del os.environ["CUDA_VISIBLE_DEVICES"]
+ if init_nvmlstatus:
+ os.environ["DASK_DISTRIBUTED__DIAGNOSTICS__NVML"] = init_nvmlstatus
diff --git a/dask_cuda/tests/test_proxify_host_file.py b/dask_cuda/tests/test_proxify_host_file.py
index 822e20fae..02094bece 100644
--- a/dask_cuda/tests/test_proxify_host_file.py
+++ b/dask_cuda/tests/test_proxify_host_file.py
@@ -1,3 +1,5 @@
+from typing import Iterable
+
import numpy as np
import pytest
from pandas.testing import assert_frame_equal
@@ -5,13 +7,16 @@
import dask
import dask.dataframe
from dask.dataframe.shuffle import shuffle_group
+from dask.sizeof import sizeof
from distributed import Client
+from distributed.client import wait
+from distributed.worker import get_worker
import dask_cuda
import dask_cuda.proxify_device_objects
-import dask_cuda.proxy_object
from dask_cuda.get_device_memory_objects import get_device_memory_objects
from dask_cuda.proxify_host_file import ProxifyHostFile
+from dask_cuda.proxy_object import ProxyObject, asproxy
cupy = pytest.importorskip("cupy")
cupy.cuda.set_allocator(None)
@@ -24,53 +29,146 @@
dask_cuda.proxify_device_objects.ignore_types = ()
-def test_one_item_limit():
- dhf = ProxifyHostFile(device_memory_limit=one_item_nbytes)
- dhf["k1"] = one_item_array() + 42
- dhf["k2"] = one_item_array()
+def is_proxies_equal(p1: Iterable[ProxyObject], p2: Iterable[ProxyObject]):
+ """Check that two collections of proxies contain the same proxies (unordered)
+
+ In order to avoid deserializing proxy objects when comparing them,
+ this function compares object IDs.
+ """
+
+ ids1 = sorted([id(p) for p in p1])
+ ids2 = sorted([id(p) for p in p2])
+ return ids1 == ids2
+
+
+def test_one_dev_item_limit():
+ dhf = ProxifyHostFile(device_memory_limit=one_item_nbytes, memory_limit=1000)
+
+ a1 = one_item_array() + 42
+ a2 = one_item_array()
+ dhf["k1"] = a1
+ dhf["k2"] = a2
+ dhf.manager.validate()
# Check k1 is spilled because of the newer k2
k1 = dhf["k1"]
k2 = dhf["k2"]
assert k1._obj_pxy_is_serialized()
assert not k2._obj_pxy_is_serialized()
+ dhf.manager.validate()
+ assert is_proxies_equal(dhf.manager._host, [k1])
+ assert is_proxies_equal(dhf.manager._dev, [k2])
# Accessing k1 spills k2 and unspill k1
k1_val = k1[0]
assert k1_val == 42
assert k2._obj_pxy_is_serialized()
+ dhf.manager.validate()
+ assert is_proxies_equal(dhf.manager._host, [k2])
+ assert is_proxies_equal(dhf.manager._dev, [k1])
# Duplicate arrays changes nothing
dhf["k3"] = [k1, k2]
assert not k1._obj_pxy_is_serialized()
assert k2._obj_pxy_is_serialized()
+ dhf.manager.validate()
+ assert is_proxies_equal(dhf.manager._host, [k2])
+ assert is_proxies_equal(dhf.manager._dev, [k1])
# Adding a new array spills k1 and k2
dhf["k4"] = one_item_array()
+ k4 = dhf["k4"]
assert k1._obj_pxy_is_serialized()
assert k2._obj_pxy_is_serialized()
assert not dhf["k4"]._obj_pxy_is_serialized()
+ dhf.manager.validate()
+ assert is_proxies_equal(dhf.manager._host, [k1, k2])
+ assert is_proxies_equal(dhf.manager._dev, [k4])
# Accessing k2 spills k1 and k4
k2[0]
assert k1._obj_pxy_is_serialized()
assert dhf["k4"]._obj_pxy_is_serialized()
assert not k2._obj_pxy_is_serialized()
+ dhf.manager.validate()
+ assert is_proxies_equal(dhf.manager._host, [k1, k4])
+ assert is_proxies_equal(dhf.manager._dev, [k2])
# Deleting k2 does not change anything since k3 still holds a
# reference to the underlying proxy object
- assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes
- p1 = list(dhf.proxies_tally.get_unspilled_proxies())
- assert len(p1) == 1
+ assert dhf.manager.get_dev_access_info()[0] == one_item_nbytes
+ dhf.manager.validate()
+ assert is_proxies_equal(dhf.manager._host, [k1, k4])
+ assert is_proxies_equal(dhf.manager._dev, [k2])
del dhf["k2"]
- assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes
- p2 = list(dhf.proxies_tally.get_unspilled_proxies())
- assert len(p2) == 1
- assert p1[0] is p2[0]
+ dhf.manager.validate()
+ assert is_proxies_equal(dhf.manager._host, [k1, k4])
+ assert is_proxies_equal(dhf.manager._dev, [k2])
- # Overwriting "k3" with a non-cuda object, should be noticed
+ # Overwriting "k3" with a non-cuda object and deleting `k2`
+ # should empty the device
dhf["k3"] = "non-cuda-object"
- assert dhf.proxies_tally.get_dev_mem_usage() == 0
+ del k2
+ dhf.manager.validate()
+ assert is_proxies_equal(dhf.manager._host, [k1, k4])
+ assert is_proxies_equal(dhf.manager._dev, [])
+
+
+def test_one_item_host_limit():
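+ # Set the host memory limit to the size of a single serialized item so
+ # that host memory can hold exactly one proxy at a time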
+ memory_limit = sizeof(asproxy(one_item_array(), serializers=("dask", "pickle")))
+ dhf = ProxifyHostFile(
+ device_memory_limit=one_item_nbytes, memory_limit=memory_limit
+ )
+
+ a1 = one_item_array() + 1
+ a2 = one_item_array() + 2
+ dhf["k1"] = a1
+ dhf["k2"] = a2
+ dhf.manager.validate()
+
+ # Check k1 is spilled because of the newer k2
+ k1 = dhf["k1"]
+ k2 = dhf["k2"]
+ assert k1._obj_pxy_is_serialized()
+ assert not k2._obj_pxy_is_serialized()
+ dhf.manager.validate()
+ assert is_proxies_equal(dhf.manager._disk, [])
+ assert is_proxies_equal(dhf.manager._host, [k1])
+ assert is_proxies_equal(dhf.manager._dev, [k2])
+
+ # Check k1 is spilled to disk and k2 is spilled to host
+ dhf["k3"] = one_item_array() + 3
+ k3 = dhf["k3"]
+ dhf.manager.validate()
+ assert is_proxies_equal(dhf.manager._disk, [k1])
+ assert is_proxies_equal(dhf.manager._host, [k2])
+ assert is_proxies_equal(dhf.manager._dev, [k3])
+
+ dhf.manager.validate()
+
+ # Accessing k2 spills k3 and unspills k2
+ k2_val = k2[0]
+ assert k2_val == 2
+ dhf.manager.validate()
+ assert is_proxies_equal(dhf.manager._disk, [k1])
+ assert is_proxies_equal(dhf.manager._host, [k3])
+ assert is_proxies_equal(dhf.manager._dev, [k2])
+
+ # Adding a new array spills k3 to disk and k2 to host
+ dhf["k4"] = one_item_array() + 4
+ k4 = dhf["k4"]
+ dhf.manager.validate()
+ assert is_proxies_equal(dhf.manager._disk, [k1, k3])
+ assert is_proxies_equal(dhf.manager._host, [k2])
+ assert is_proxies_equal(dhf.manager._dev, [k4])
+
+ # Accessing k1 unspills k1 directly to device and spills k4 to host
+ k1_val = k1[0]
+ assert k1_val == 1
+ dhf.manager.validate()
+ assert is_proxies_equal(dhf.manager._disk, [k2, k3])
+ assert is_proxies_equal(dhf.manager._host, [k4])
+ assert is_proxies_equal(dhf.manager._dev, [k1])
@pytest.mark.parametrize("jit_unspill", [True, False])
@@ -84,7 +182,7 @@ def task(x):
if jit_unspill:
# Check that `x` is a proxy object and the proxied DataFrame is serialized
assert "FrameProxyObject" in str(type(x))
- assert x._obj_pxy["serializers"] == ("dask", "pickle")
+ assert x._obj_pxy["serializer"] == "dask"
else:
assert type(x) == cudf.DataFrame
assert len(x) == 10 # Trigger deserialization
@@ -114,7 +212,7 @@ def test_dataframes_share_dev_mem():
# They still share the same underlying device memory
assert view1["a"].data._owner._owner is view2["a"].data._owner._owner
- dhf = ProxifyHostFile(device_memory_limit=160)
+ dhf = ProxifyHostFile(device_memory_limit=160, memory_limit=1000)
dhf["v1"] = view1
dhf["v2"] = view2
v1 = dhf["v1"]
@@ -141,59 +239,49 @@ def test_cudf_get_device_memory_objects():
def test_externals():
- dhf = ProxifyHostFile(device_memory_limit=one_item_nbytes)
+ """Test adding objects directly to the manager
+
+ Adding an object directly to the manager makes it count against the
+ device_memory_limit but it isn't part of the store.
+
+ Normally, we use __setitem__ to store objects in the hostfile and make them
+ count against the device_memory_limit, with the inherent consequence that
+ the objects are not freeable before subsequent calls to __delitem__.
+ This is a problem for long-running tasks that want objects to count against
+ the device_memory_limit while freeing them ASAP without explicit calls to
+ __delitem__.
+ """
+ dhf = ProxifyHostFile(device_memory_limit=one_item_nbytes, memory_limit=1000)
dhf["k1"] = one_item_array()
k1 = dhf["k1"]
- k2 = dhf.add_external(one_item_array())
+ k2 = dhf.manager.proxify(one_item_array())
# `k2` isn't part of the store but still triggers spilling of `k1`
assert len(dhf) == 1
assert k1._obj_pxy_is_serialized()
assert not k2._obj_pxy_is_serialized()
+ assert is_proxies_equal(dhf.manager._host, [k1])
+ assert is_proxies_equal(dhf.manager._dev, [k2])
+ assert dhf.manager._dev._mem_usage == one_item_nbytes
+
k1[0] # Trigger spilling of `k2`
assert not k1._obj_pxy_is_serialized()
assert k2._obj_pxy_is_serialized()
+ assert is_proxies_equal(dhf.manager._host, [k2])
+ assert is_proxies_equal(dhf.manager._dev, [k1])
+ assert dhf.manager._dev._mem_usage == one_item_nbytes
+
k2[0] # Trigger spilling of `k1`
assert k1._obj_pxy_is_serialized()
assert not k2._obj_pxy_is_serialized()
- assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes
+ assert is_proxies_equal(dhf.manager._host, [k1])
+ assert is_proxies_equal(dhf.manager._dev, [k2])
+ assert dhf.manager._dev._mem_usage == one_item_nbytes
+
# Removing `k2` also removes it from the tally
del k2
- assert dhf.proxies_tally.get_dev_mem_usage() == 0
- assert len(list(dhf.proxies_tally.get_unspilled_proxies())) == 0
-
-
-def test_externals_setitem():
- dhf = ProxifyHostFile(device_memory_limit=one_item_nbytes)
- k1 = dhf.add_external(one_item_array())
- assert type(k1) is dask_cuda.proxy_object.ProxyObject
- assert len(dhf) == 0
- assert "external" in k1._obj_pxy
- assert "external_finalize" in k1._obj_pxy
- dhf["k1"] = k1
- k1 = dhf["k1"]
- assert type(k1) is dask_cuda.proxy_object.ProxyObject
- assert len(dhf) == 1
- assert "external" not in k1._obj_pxy
- assert "external_finalize" not in k1._obj_pxy
-
- k1 = dhf.add_external(one_item_array())
- k1._obj_pxy_serialize(serializers=("dask", "pickle"))
- dhf["k1"] = k1
- k1 = dhf["k1"]
- assert type(k1) is dask_cuda.proxy_object.ProxyObject
- assert len(dhf) == 1
- assert "external" not in k1._obj_pxy
- assert "external_finalize" not in k1._obj_pxy
-
- dhf["k1"] = one_item_array()
- assert len(dhf.proxies_tally.proxy_id_to_proxy) == 1
- assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes
- k1 = dhf.add_external(k1)
- assert len(dhf.proxies_tally.proxy_id_to_proxy) == 1
- assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes
- k1 = dhf.add_external(dhf["k1"])
- assert len(dhf.proxies_tally.proxy_id_to_proxy) == 1
- assert dhf.proxies_tally.get_dev_mem_usage() == one_item_nbytes
+ assert is_proxies_equal(dhf.manager._host, [k1])
+ assert is_proxies_equal(dhf.manager._dev, [])
+ assert dhf.manager._dev._mem_usage == 0
def test_proxify_device_objects_of_cupy_array():
@@ -247,3 +335,41 @@ def is_proxy_object(x):
assert not any(res) # No proxy objects
else:
assert all(res) # Only proxy objects
+
+
+def test_worker_force_spill_to_disk():
+ """ Test Dask triggering CPU-to-Disk spilling """
+ cudf = pytest.importorskip("cudf")
+
+ with dask.config.set({"distributed.worker.memory.terminate": 0}):
+ with dask_cuda.LocalCUDACluster(
+ n_workers=1, device_memory_limit="1MB", jit_unspill=True
+ ) as cluster:
+ with Client(cluster) as client:
+ # Create a df that is spilled to host memory immediately
+ df = cudf.DataFrame({"key": np.arange(10 ** 8)})
+ ddf = dask.dataframe.from_pandas(df, npartitions=1).persist()
+ wait(ddf)
+
+ def f():
+ """Trigger a memory_monitor() and reset memory_limit"""
+ w = get_worker()
+
+ async def y():
+ # Set a host memory limit that triggers spilling to disk
+ w.memory_pause_fraction = False
+ memory = w.monitor.proc.memory_info().rss
+ w.memory_limit = memory - 10 ** 8
+ w.memory_target_fraction = 1
+ await w.memory_monitor()
+ # Check that host memory is freed
+ assert w.monitor.proc.memory_info().rss < memory - 10 ** 7
+ w.memory_limit = memory * 10 # Un-limit
+
+ w.loop.add_callback(y)
+
+ wait(client.submit(f))
+ # Check that the worker doesn't complain about unmanaged memory
+ assert "Unmanaged memory use is high" not in str(
+ client.get_worker_logs()
+ )
diff --git a/dask_cuda/tests/test_proxy.py b/dask_cuda/tests/test_proxy.py
index 6d3f1c972..4b87e09fa 100644
--- a/dask_cuda/tests/test_proxy.py
+++ b/dask_cuda/tests/test_proxy.py
@@ -2,9 +2,10 @@
import pickle
from types import SimpleNamespace
+import numpy as np
import pandas
import pytest
-from pandas.testing import assert_frame_equal
+from pandas.testing import assert_frame_equal, assert_series_equal
import dask
import dask.array
@@ -17,48 +18,71 @@
import dask_cuda
from dask_cuda import proxy_object
from dask_cuda.proxify_device_objects import proxify_device_objects
+from dask_cuda.proxify_host_file import ProxifyHostFile
+ProxifyHostFile.register_disk_spilling() # Make the "disk" serializer available
-@pytest.mark.parametrize("serializers", [None, ("dask", "pickle")])
+
+@pytest.mark.parametrize("serializers", [None, ("dask", "pickle"), ("disk",)])
def test_proxy_object(serializers):
"""Check "transparency" of the proxy object"""
- org = list(range(10))
+ org = bytearray(range(10))
pxy = proxy_object.asproxy(org, serializers=serializers)
assert len(org) == len(pxy)
assert org[0] == pxy[0]
assert 1 in pxy
- assert -1 not in pxy
+ assert 10 not in pxy
assert str(org) == str(pxy)
assert "dask_cuda.proxy_object.ProxyObject at " in repr(pxy)
- assert "list at " in repr(pxy)
+ assert "bytearray at " in repr(pxy)
pxy._obj_pxy_serialize(serializers=("dask", "pickle"))
assert "dask_cuda.proxy_object.ProxyObject at " in repr(pxy)
- assert "list (serialized=('dask', 'pickle'))" in repr(pxy)
+ assert "bytearray (serialized='dask')" in repr(pxy)
assert org == proxy_object.unproxy(pxy)
assert org == proxy_object.unproxy(org)
-@pytest.mark.parametrize("serializers_first", [None, ("dask", "pickle")])
-@pytest.mark.parametrize("serializers_second", [None, ("dask", "pickle")])
+class DummyObj:
+ """Class that only "pickle" can serialize"""
+
+ def __reduce__(self):
+ return (DummyObj, ())
+
+
+def test_proxy_object_serializer():
+ """Check the serializers argument"""
+ pxy = proxy_object.asproxy(DummyObj(), serializers=("dask", "pickle"))
+ assert pxy._obj_pxy["serializer"] == "pickle"
+ assert "DummyObj (serialized='pickle')" in repr(pxy)
+
+ with pytest.raises(ValueError) as excinfo:
+ pxy = proxy_object.asproxy([42], serializers=("dask", "pickle"))
+ assert "Cannot wrap a collection" in str(excinfo.value)
+
+
+@pytest.mark.parametrize("serializers_first", [None, ("dask", "pickle"), ("disk",)])
+@pytest.mark.parametrize("serializers_second", [None, ("dask", "pickle"), ("disk",)])
def test_double_proxy_object(serializers_first, serializers_second):
"""Check asproxy() when creating a proxy object of a proxy object"""
- org = list(range(10))
+ serializer1 = serializers_first[0] if serializers_first else None
+ serializer2 = serializers_second[0] if serializers_second else None
+ org = bytearray(range(10))
pxy1 = proxy_object.asproxy(org, serializers=serializers_first)
- assert pxy1._obj_pxy["serializers"] == serializers_first
+ assert pxy1._obj_pxy["serializer"] == serializer1
pxy2 = proxy_object.asproxy(pxy1, serializers=serializers_second)
if serializers_second is None:
# Check that `serializers=None` doesn't change the initial serializers
- assert pxy2._obj_pxy["serializers"] == serializers_first
+ assert pxy2._obj_pxy["serializer"] == serializer1
else:
- assert pxy2._obj_pxy["serializers"] == serializers_second
+ assert pxy2._obj_pxy["serializer"] == serializer2
assert pxy1 is pxy2
-@pytest.mark.parametrize("serializers", [None, ("dask", "pickle")])
+@pytest.mark.parametrize("serializers", [None, ("dask", "pickle"), ("disk",)])
@pytest.mark.parametrize("backend", ["numpy", "cupy"])
def test_proxy_object_of_array(serializers, backend):
"""Check that a proxied array behaves as a regular (numpy or cupy) array"""
@@ -180,7 +204,7 @@ def test_proxy_object_of_array(serializers, backend):
assert all(expect == got)
-@pytest.mark.parametrize("serializers", [None, ["dask"]])
+@pytest.mark.parametrize("serializers", [None, ["dask"], ["disk"]])
def test_proxy_object_of_cudf(serializers):
"""Check that a proxied cudf dataframe behaves as a regular dataframe"""
cudf = pytest.importorskip("cudf")
@@ -189,14 +213,13 @@ def test_proxy_object_of_cudf(serializers):
assert_frame_equal(df.to_pandas(), pxy.to_pandas())
-@pytest.mark.parametrize("proxy_serializers", [None, ["dask"], ["cuda"]])
+@pytest.mark.parametrize("proxy_serializers", [None, ["dask"], ["cuda"], ["disk"]])
@pytest.mark.parametrize("dask_serializers", [["dask"], ["cuda"]])
def test_serialize_of_proxied_cudf(proxy_serializers, dask_serializers):
"""Check that we can serialize a proxied cudf dataframe, which might
be serialized already.
"""
cudf = pytest.importorskip("cudf")
-
df = cudf.DataFrame({"a": range(10)})
pxy = proxy_object.asproxy(df, serializers=proxy_serializers)
header, frames = serialize(pxy, serializers=dask_serializers, on_error="raise")
@@ -257,7 +280,7 @@ def task(x):
if jit_unspill:
# Check that `x` is a proxy object and the proxied DataFrame is serialized
assert "FrameProxyObject" in str(type(x))
- assert x._obj_pxy["serializers"] == ("dask", "pickle")
+ assert x._obj_pxy["serializer"] == "dask"
else:
assert type(x) == cudf.DataFrame
assert len(x) == 10 # Trigger deserialization
@@ -279,6 +302,45 @@ def task(x):
assert_frame_equal(got.to_pandas(), df.to_pandas())
+@pytest.mark.parametrize("obj", [bytearray(10), bytearray(10 ** 6)])
+def test_serializing_to_disk(obj):
+ """Check serializing to disk"""
+
+ if isinstance(obj, str):
+ backend = pytest.importorskip(obj)
+ obj = backend.arange(100)
+
+ # Serialize from device to disk
+ pxy = proxy_object.asproxy(obj)
+ ProxifyHostFile.serialize_proxy_to_disk_inplace(pxy)
+ assert pxy._obj_pxy["serializer"] == "disk"
+ assert obj == proxy_object.unproxy(pxy)
+
+ # Serialize from host to disk
+ pxy = proxy_object.asproxy(obj, serializers=("pickle",))
+ ProxifyHostFile.serialize_proxy_to_disk_inplace(pxy)
+ assert pxy._obj_pxy["serializer"] == "disk"
+ assert obj == proxy_object.unproxy(pxy)
+
+
+@pytest.mark.parametrize("size", [10, 10 ** 4])
+@pytest.mark.parametrize(
+ "serializers", [None, ["dask"], ["cuda", "dask"], ["pickle"], ["disk"]]
+)
+@pytest.mark.parametrize("backend", ["numpy", "cupy"])
+def test_serializing_array_to_disk(backend, serializers, size):
+ """Check serializing arrays to disk"""
+
+ np = pytest.importorskip(backend)
+ obj = np.arange(size)
+
+ # Serialize from host to disk
+ pxy = proxy_object.asproxy(obj, serializers=serializers)
+ ProxifyHostFile.serialize_proxy_to_disk_inplace(pxy)
+ assert pxy._obj_pxy["serializer"] == "disk"
+ assert list(obj) == list(proxy_object.unproxy(pxy))
+
+
class _PxyObjTest(proxy_object.ProxyObject):
"""
A class that:
@@ -292,7 +354,7 @@ def __dask_tokenize__(self):
def _obj_pxy_deserialize(self):
if self._obj_pxy["assert_on_deserializing"]:
- assert self._obj_pxy["serializers"] is None
+ assert self._obj_pxy["serializer"] is None
return super()._obj_pxy_deserialize()
@@ -305,16 +367,16 @@ def test_communicating_proxy_objects(protocol, send_serializers):
def task(x):
# Check that the subclass survives the trip from client to worker
assert isinstance(x, _PxyObjTest)
- serializers_used = x._obj_pxy["serializers"]
+ serializers_used = x._obj_pxy["serializer"]
# Check that `x` is serialized with the expected serializers
if protocol == "ucx":
if send_serializers is None:
- assert serializers_used == ("cuda",)
+ assert serializers_used == "cuda"
else:
- assert serializers_used == send_serializers
+ assert serializers_used == send_serializers[0]
else:
- assert serializers_used == ("dask", "pickle")
+ assert serializers_used == "dask"
with dask_cuda.LocalCUDACluster(
n_workers=1, protocol=protocol, enable_tcp_over_ucx=protocol == "ucx"
@@ -337,9 +399,37 @@ def task(x):
client.shutdown() # Avoids a UCX shutdown error
+@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
+@pytest.mark.parametrize("shared_fs", [True, False])
+def test_communicating_disk_objects(protocol, shared_fs):
+ """Testing disk serialization of cuDF dataframe when communicating"""
+ cudf = pytest.importorskip("cudf")
+ ProxifyHostFile._spill_shared_filesystem = shared_fs
+
+ def task(x):
+ # Check that the subclass survives the trip from client to worker
+ assert isinstance(x, _PxyObjTest)
+ serializer_used = x._obj_pxy["serializer"]
+ if shared_fs:
+ assert serializer_used == "disk"
+ else:
+ assert serializer_used == "dask"
+
+ with dask_cuda.LocalCUDACluster(
+ n_workers=1, protocol=protocol, enable_tcp_over_ucx=protocol == "ucx"
+ ) as cluster:
+ with Client(cluster) as client:
+ df = cudf.DataFrame({"a": range(10)})
+ df = proxy_object.asproxy(df, serializers=("disk",), subclass=_PxyObjTest)
+ df._obj_pxy["assert_on_deserializing"] = False
+ df = client.scatter(df)
+ client.submit(task, df).result()
+ client.shutdown() # Avoids a UCX shutdown error
+
+
@pytest.mark.parametrize("array_module", ["numpy", "cupy"])
@pytest.mark.parametrize(
- "serializers", [None, ("dask", "pickle"), ("cuda", "dask", "pickle")]
+ "serializers", [None, ("dask", "pickle"), ("cuda", "dask", "pickle"), ("disk",)]
)
def test_pickle_proxy_object(array_module, serializers):
"""Check pickle of the proxy object"""
@@ -458,3 +548,18 @@ def test_merge_sorted_of_proxied_cudf_dataframes():
got = cudf.merge_sorted(proxify_device_objects(dfs, {}, []))
expected = cudf.merge_sorted(dfs)
assert_frame_equal(got.to_pandas(), expected.to_pandas())
+
+
+@pytest.mark.parametrize(
+ "np_func", [np.less, np.less_equal, np.greater, np.greater_equal, np.equal]
+)
+def test_array_ufucn_proxified_object(np_func):
+ cudf = pytest.importorskip("cudf")
+
+ np_array = np.array(100)
+ ser = cudf.Series([1, 2, 3])
+ proxy_obj = proxify_device_objects(ser)
+ expected = np_func(ser, np_array)
+ actual = np_func(proxy_obj, np_array)
+
+ assert_series_equal(expected.to_pandas(), actual.to_pandas())
diff --git a/dask_cuda/tests/test_ucx_options.py b/dask_cuda/tests/test_ucx_options.py
deleted file mode 100644
index 77af9357b..000000000
--- a/dask_cuda/tests/test_ucx_options.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import multiprocessing as mp
-
-import numpy
-import pytest
-
-import dask
-from dask import array as da
-from distributed import Client
-from distributed.deploy.local import LocalCluster
-
-from dask_cuda.utils import _ucx_110
-
-mp = mp.get_context("spawn")
-ucp = pytest.importorskip("ucp")
-
-# Notice, all of the following tests is executed in a new process such
-# that UCX options of the different tests doesn't conflict.
-# Furthermore, all tests do some computation to trigger initialization
-# of UCX before retrieving the current config.
-
-
-def _test_global_option(seg_size):
- """Test setting UCX options through dask's global config"""
- tls = "tcp,cuda_copy" if _ucx_110 else "tcp,sockcm,cuda_copy"
- tls_priority = "tcp" if _ucx_110 else "sockcm"
- dask.config.update(
- dask.config.global_config,
- {
- "ucx": {
- "SEG_SIZE": seg_size,
- "TLS": tls,
- "SOCKADDR_TLS_PRIORITY": tls_priority,
- },
- },
- priority="new",
- )
-
- with LocalCluster(
- protocol="ucx",
- dashboard_address=None,
- n_workers=1,
- threads_per_worker=1,
- processes=True,
- ) as cluster:
- with Client(cluster):
- res = da.from_array(numpy.arange(10000), chunks=(1000,))
- res = res.sum().compute()
- assert res == 49995000
- conf = ucp.get_config()
- assert conf["SEG_SIZE"] == seg_size
-
-
-@pytest.mark.xfail(reason="https://github.com/rapidsai/dask-cuda/issues/627")
-def test_global_option():
- for seg_size in ["2K", "1M", "2M"]:
- p = mp.Process(target=_test_global_option, args=(seg_size,))
- p.start()
- p.join()
- assert not p.exitcode
diff --git a/dask_cuda/tests/test_utils.py b/dask_cuda/tests/test_utils.py
index edfb04623..c6838c323 100644
--- a/dask_cuda/tests/test_utils.py
+++ b/dask_cuda/tests/test_utils.py
@@ -249,3 +249,32 @@ def test_parse_device_memory_limit():
assert parse_device_memory_limit(0.8) == int(total * 0.8)
assert parse_device_memory_limit(1000000000) == 1000000000
assert parse_device_memory_limit("1GB") == 1000000000
+
+
+def test_parse_visible_mig_devices():
+ pynvml = pytest.importorskip("pynvml")
+ pynvml.nvmlInit()
+ for index in range(get_gpu_count()):
+ handle = pynvml.nvmlDeviceGetHandleByIndex(index)
+ try:
+ mode = pynvml.nvmlDeviceGetMigMode(handle)[0]
+ except pynvml.NVMLError:
+ # if not a MIG device, i.e. a normal GPU, skip
+ continue
+ if mode:
+ # Just check to see if there are any MIG-enabled GPUs.
+ # If there is one, check if the number of MIG instances
+ # in that GPU is <= count, where count gives us the
+ # maximum number of MIG devices/instances that can exist
+ # under a given parent NVML device.
+ count = pynvml.nvmlDeviceGetMaxMigDeviceCount(handle)
+ miguuids = []
+ for i in range(count):
+ try:
+ mighandle = pynvml.nvmlDeviceGetMigDeviceHandleByIndex(
+ device=handle, index=i
+ )
+ miguuids.append(mighandle)
+ except pynvml.NVMLError:
+ pass
+ assert len(miguuids) <= count
diff --git a/dask_cuda/utils.py b/dask_cuda/utils.py
index 171af01a8..457306bcf 100644
--- a/dask_cuda/utils.py
+++ b/dask_cuda/utils.py
@@ -130,13 +130,50 @@ def get_gpu_count():
return pynvml.nvmlDeviceGetCount()
-def get_cpu_affinity(device_index):
+@toolz.memoize
+def get_gpu_count_mig(return_uuids=False):
+ """Return the number of MIG instances available
+
+ Parameters
+ ----------
+ return_uuids: bool
+ If True, also return the UUIDs of the available MIG instances
+
+ """
+ pynvml.nvmlInit()
+ uuids = []
+ for index in range(get_gpu_count()):
+ handle = pynvml.nvmlDeviceGetHandleByIndex(index)
+ try:
+ is_mig_mode = pynvml.nvmlDeviceGetMigMode(handle)[0]
+ except pynvml.NVMLError:
+ # if not a MIG device, i.e. a normal GPU, skip
+ continue
+ if is_mig_mode:
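+ # Query every possible MIG slot on this GPU; slots that are not
+ # populated raise NVMLError and are skipped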
+ count = pynvml.nvmlDeviceGetMaxMigDeviceCount(handle)
+ miguuids = []
+ for i in range(count):
+ try:
+ mighandle = pynvml.nvmlDeviceGetMigDeviceHandleByIndex(
+ device=handle, index=i
+ )
+ miguuids.append(mighandle)
+ uuids.append(pynvml.nvmlDeviceGetUUID(mighandle))
+ except pynvml.NVMLError:
+ pass
+ if return_uuids:
+ return len(uuids), uuids
+ return len(uuids)
+
+
+def get_cpu_affinity(device_index=None):
"""Get a list containing the CPU indices to which a GPU is directly connected.
+ The device can be specified either by its index or by its UUID.
Parameters
----------
- device_index: int
- Index of the GPU device
+ device_index: int or str
+ Index or UUID of the GPU device
Examples
--------
@@ -158,10 +195,19 @@ def get_cpu_affinity(device_index):
pynvml.nvmlInit()
try:
+ if device_index and not str(device_index).isnumeric():
+ # This means device_index is UUID.
+ # This works for both MIG and non-MIG device UUIDs.
+ handle = pynvml.nvmlDeviceGetHandleByUUID(str.encode(device_index))
+ if pynvml.nvmlDeviceIsMigDeviceHandle(handle):
+ # Additionally get parent device handle
+ # if the device itself is a MIG instance
+ handle = pynvml.nvmlDeviceGetDeviceHandleFromMigDeviceHandle(handle)
+ else:
+ handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
# Result is a list of 64-bit integers, thus ceil(get_cpu_count() / 64)
affinity = pynvml.nvmlDeviceGetCpuAffinity(
- pynvml.nvmlDeviceGetHandleByIndex(device_index),
- math.ceil(get_cpu_count() / 64),
+ handle, math.ceil(get_cpu_count() / 64),
)
return unpack_bitmask(affinity)
except pynvml.NVMLError:
@@ -181,12 +227,17 @@ def get_n_gpus():
def get_device_total_memory(index=0):
"""
- Return total memory of CUDA device with index
+ Return total memory of CUDA device with index or with device identifier UUID
"""
pynvml.nvmlInit()
- return pynvml.nvmlDeviceGetMemoryInfo(
- pynvml.nvmlDeviceGetHandleByIndex(index)
- ).total
+
+ if index and not str(index).isnumeric():
+ # This means index is UUID. This works for both MIG and non-MIG device UUIDs.
+ handle = pynvml.nvmlDeviceGetHandleByUUID(str.encode(str(index)))
+ else:
+ # This is a device index
+ handle = pynvml.nvmlDeviceGetHandleByIndex(index)
+ return pynvml.nvmlDeviceGetMemoryInfo(handle).total
def get_ucx_net_devices(
@@ -464,12 +515,13 @@ def parse_cuda_visible_device(dev):
try:
return int(dev)
except ValueError:
- if any(dev.startswith(prefix) for prefix in ["GPU-", "MIG-GPU-"]):
+ if any(dev.startswith(prefix) for prefix in ["GPU-", "MIG-GPU-", "MIG-"]):
return dev
else:
raise ValueError(
"Devices in CUDA_VISIBLE_DEVICES must be comma-separated integers "
- "or strings beginning with 'GPU-' or 'MIG-GPU-' prefixes."
+ "or strings beginning with 'GPU-' or 'MIG-GPU-' prefixes"
+ " or 'MIG-'."
)
@@ -514,13 +566,25 @@ def nvml_device_index(i, CUDA_VISIBLE_DEVICES):
1
>>> nvml_device_index(1, [1,2,3,0])
2
+ >>> nvml_device_index(1, ["GPU-84fd49f2-48ad-50e8-9f2e-3bf0dfd47ccb",
+ "GPU-d6ac2d46-159b-5895-a854-cb745962ef0f",
+ "GPU-158153b7-51d0-5908-a67c-f406bc86be17"])
+ "GPU-d6ac2d46-159b-5895-a854-cb745962ef0f"
+ >>> nvml_device_index(2, ["MIG-41b3359c-e721-56e5-8009-12e5797ed514",
+ "MIG-65b79fff-6d3c-5490-a288-b31ec705f310",
+ "MIG-c6e2bae8-46d4-5a7e-9a68-c6cf1f680ba0"])
+ "MIG-c6e2bae8-46d4-5a7e-9a68-c6cf1f680ba0"
>>> nvml_device_index(1, 2)
Traceback (most recent call last):
...
ValueError: CUDA_VISIBLE_DEVICES must be `str` or `list`
"""
if isinstance(CUDA_VISIBLE_DEVICES, str):
- return int(CUDA_VISIBLE_DEVICES.split(",")[i])
+ ith_elem = CUDA_VISIBLE_DEVICES.split(",")[i]
+ if ith_elem.isnumeric():
+ return int(ith_elem)
+ else:
+ return ith_elem
elif isinstance(CUDA_VISIBLE_DEVICES, list):
return CUDA_VISIBLE_DEVICES[i]
else:
@@ -530,15 +594,15 @@ def nvml_device_index(i, CUDA_VISIBLE_DEVICES):
def parse_device_memory_limit(device_memory_limit, device_index=0):
"""Parse memory limit to be used by a CUDA device.
-
Parameters
----------
device_memory_limit: float, int, str or None
This can be a float (fraction of total device memory), an integer (bytes),
a string (like 5GB or 5000M), and "auto", 0 or None for the total device
size.
- device_index: int
- The index of device from which to obtain the total memory amount.
+ device_index: int or str
+ The index or UUID of the device from which to obtain the total memory amount.
+ Default: 0.
Examples
--------
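
With the changes above, CUDA_VISIBLE_DEVICES entries may be plain indices or UUID strings carrying the 'GPU-', 'MIG-GPU-', or 'MIG-' prefix, and nvml_device_index now passes UUID entries through unchanged. A small illustrative sketch of that behaviour; the helper name and UUIDs below are placeholders, not part of the diff:

def visible_device(i, cuda_visible_devices):
    # Mirrors the string branch shown above: numeric entries become ints,
    # UUID entries are returned unchanged.
    entry = cuda_visible_devices.split(",")[i]
    return int(entry) if entry.isnumeric() else entry

visible_device(1, "0,1,2")                       # -> 1
visible_device(1, "GPU-aaaa,MIG-bbbb,GPU-cccc")  # -> "MIG-bbbb"
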
diff --git a/docs/source/examples/ucx.rst b/docs/source/examples/ucx.rst
index 77b12ce65..036b99291 100644
--- a/docs/source/examples/ucx.rst
+++ b/docs/source/examples/ucx.rst
@@ -22,11 +22,13 @@ To connect a client to a cluster with all supported transports and an RMM pool:
enable_nvlink=True,
enable_infiniband=True,
enable_rdmacm=True,
- ucx_net_devices="auto",
rmm_pool_size="1GB"
)
client = Client(cluster)
+.. note::
+ For UCX 1.9 (deprecated) and older, it's necessary to pass ``ucx_net_devices="auto"`` to ``LocalCUDACluster``. UCX 1.11 and above is capable of selecting InfiniBand devices automatically.
+
dask-cuda-worker
----------------
@@ -41,18 +43,19 @@ To start a Dask scheduler using UCX with all supported transports and a gigabyte
.. code-block:: bash
- $ DASK_UCX__CUDA_COPY=True \
- > DASK_UCX__TCP=True \
- > DASK_UCX__NVLINK=True \
- > DASK_UCX__INFINIBAND=True \
- > DASK_UCX__RDMACM=True \
- > DASK_UCX__NET_DEVICES=mlx5_0:1 \
- > DASK_RMM__POOL_SIZE=1GB \
+ $ DASK_DISTRIBUTED__COMM__UCX__CUDA_COPY=True \
+ > DASK_DISTRIBUTED__COMM__UCX__TCP=True \
+ > DASK_DISTRIBUTED__COMM__UCX__NVLINK=True \
+ > DASK_DISTRIBUTED__COMM__UCX__INFINIBAND=True \
+ > DASK_DISTRIBUTED__COMM__UCX__RDMACM=True \
+ > DASK_DISTRIBUTED__RMM__POOL_SIZE=1GB \
> dask-scheduler --protocol ucx --interface ib0
-Note the specification of ``"mlx5_0:1"`` as our UCX net device; because the scheduler does not rely upon Dask-CUDA, it cannot automatically detect InfiniBand interfaces, so we must specify one explicitly.
We communicate to the scheduler that we will be using UCX with the ``--protocol`` option, and that we will be using InfiniBand with the ``--interface`` option.
+.. note::
+ For UCX 1.9 (deprecated) and older, it's also necessary to set ``DASK_DISTRIBUTED__COMM__UCX__NET_DEVICES=mlx5_0:1``, where ``"mlx5_0:1"`` is our UCX net device; because the scheduler does not rely upon Dask-CUDA, it cannot automatically detect InfiniBand interfaces, so we must specify one explicitly. UCX 1.11 and above is capable of selecting InfiniBand devices automatically.
+
Workers
^^^^^^^
@@ -66,9 +69,11 @@ To start a cluster with all supported transports and an RMM pool:
> --enable-nvlink \
> --enable-infiniband \
> --enable-rdmacm \
- > --net-devices="auto" \
> --rmm-pool-size="1GB"
+.. note::
+ For UCX 1.9 (deprecated) and older, it's also necessary to set ``--net-devices="auto"``. UCX 1.11 and above is capable of selecting InfiniBand devices automatically.
+
Client
^^^^^^
@@ -85,8 +90,8 @@ To connect a client to the cluster we have made:
enable_nvlink=True,
enable_infiniband=True,
enable_rdmacm=True,
- net_devices="mlx5_0:1",
)
client = Client("ucx://:8786")
-Note again the specification of ``"mlx5_0:1"`` as our UCX net device, due to the fact that the client does not support automatic detection of InfiniBand interfaces.
+.. note::
+ For UCX 1.9 (deprecated) and older, it's also necessary to set ``net_devices="mlx5_0:1"``, where ``"mlx5_0:1"`` is our UCX net device; because the client does not rely upon Dask-CUDA, it cannot automatically detect InfiniBand interfaces, so we must specify one explicitly. UCX 1.11 and above is capable of selecting InfiniBand devices automatically.
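
With UCX 1.11 and above no net device needs to be specified anywhere, so the LocalCUDACluster example at the top of this file reduces to the following; this is a sketch only, with an illustrative pool size:

from dask.distributed import Client
from dask_cuda import LocalCUDACluster

# All supported UCX transports; no ucx_net_devices argument required (UCX 1.11+).
cluster = LocalCUDACluster(
    protocol="ucx",
    enable_tcp_over_ucx=True,
    enable_nvlink=True,
    enable_infiniband=True,
    enable_rdmacm=True,
    rmm_pool_size="1GB",
)
client = Client(cluster)
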
diff --git a/docs/source/ucx.rst b/docs/source/ucx.rst
index 1bc262b93..4246f541a 100644
--- a/docs/source/ucx.rst
+++ b/docs/source/ucx.rst
@@ -27,30 +27,30 @@ In addition to installations of UCX and UCX-Py on your system, several options m
Typically, these will affect ``UCX_TLS`` and ``UCX_SOCKADDR_TLS_PRIORITY``, environment variables used by UCX to decide what transport methods to use and which to prioritize, respectively.
However, some will affect related libraries, such as RMM:
-- ``ucx.cuda_copy: true`` -- **required.**
+- ``distributed.comm.ucx.cuda_copy: true`` -- **required.**
Adds ``cuda_copy`` to ``UCX_TLS``, enabling CUDA transfers over UCX.
-- ``ucx.tcp: true`` -- **required.**
+- ``distributed.comm.ucx.tcp: true`` -- **required.**
Adds ``tcp`` to ``UCX_TLS``, enabling TCP transfers over UCX; this is required for very small transfers which are inefficient for NVLink and InfiniBand.
-- ``ucx.nvlink: true`` -- **required for NVLink.**
+- ``distributed.comm.ucx.nvlink: true`` -- **required for NVLink.**
Adds ``cuda_ipc`` to ``UCX_TLS``, enabling NVLink transfers over UCX; affects intra-node communication only.
-- ``ucx.infiniband: true`` -- **required for InfiniBand.**
+- ``distributed.comm.ucx.infiniband: true`` -- **required for InfiniBand.**
Adds ``rc`` to ``UCX_TLS``, enabling InfiniBand transfers over UCX.
For optimal performance with UCX 1.11 and above, it is recommended to also set the environment variables ``UCX_MAX_RNDV_RAILS=1`` and ``UCX_MEMTYPE_REG_WHOLE_ALLOC_TYPES=cuda``; see the documentation `here `_ and `here `_ for more details on those variables.
-- ``ucx.rdmacm: true`` -- **recommended for InfiniBand.**
+- ``distributed.comm.ucx.rdmacm: true`` -- **recommended for InfiniBand.**
Replaces ``sockcm`` with ``rdmacm`` in ``UCX_SOCKADDR_TLS_PRIORITY``, enabling remote direct memory access (RDMA) for InfiniBand transfers.
This is recommended by UCX for use with InfiniBand, and will not work if InfiniBand is disabled.
-- ``ucx.net-devices: `` -- **recommended for UCX 1.9 and older.**
+- ``distributed.comm.ucx.net-devices: `` -- **recommended for UCX 1.9 and older.**
Explicitly sets ``UCX_NET_DEVICES`` instead of defaulting to ``"all"``, which can result in suboptimal performance.
If using InfiniBand, set to ``"auto"`` to automatically detect the InfiniBand interface closest to each GPU on UCX 1.9 and below.
@@ -65,14 +65,14 @@ However, some will affect related libraries, such as RMM:
-- ``rmm.pool-size: `` -- **recommended.**
+- ``distributed.rmm.pool-size: `` -- **recommended.**
Allocates an RMM pool of the specified size for the process; size can be provided with an integer number of bytes or in human readable format, e.g. ``"4GB"``.
It is recommended to set the pool size to at least the minimum amount of memory used by the process; if possible, one can map all GPU memory to a single pool, to be utilized for the lifetime of the process.
.. note::
These options can be used with mainline Dask.distributed.
- However, some features are exclusive to Dask-CUDA, such as the automatic detection of InfiniBand interfaces.
+ However, some features are exclusive to Dask-CUDA, such as the automatic detection of InfiniBand interfaces.
See `Dask-CUDA -- Motivation `_ for more details on the benefits of using Dask-CUDA.
Usage
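
The options above live under the distributed.comm.ucx and distributed.rmm namespaces, so besides YAML and environment variables they can be set from Python through dask.config before a scheduler, worker, or cluster is created. A minimal sketch with illustrative values (dask.config treats '-' and '_' in key names interchangeably):

import dask

dask.config.set({
    "distributed.comm.ucx.cuda_copy": True,   # required
    "distributed.comm.ucx.tcp": True,         # required
    "distributed.comm.ucx.nvlink": True,      # required for NVLink
    "distributed.comm.ucx.infiniband": True,  # required for InfiniBand
    "distributed.comm.ucx.rdmacm": True,      # recommended for InfiniBand
    "distributed.rmm.pool-size": "1GB",       # recommended
})
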
diff --git a/examples/ucx/dask_cuda_worker.sh b/examples/ucx/dask_cuda_worker.sh
index 27d113fdf..f1ec98186 100644
--- a/examples/ucx/dask_cuda_worker.sh
+++ b/examples/ucx/dask_cuda_worker.sh
@@ -23,9 +23,9 @@ if [ -z ${interface+x} ] && ! [ -z ${transport+x} ]; then
fi
# set up environment variables/flags
-DASK_UCX__CUDA_COPY=True
-DASK_UCX__TCP=True
-DASK_RMM__POOL_SIZE=$rmm_pool_size
+DASK_DISTRIBUTED__COMM__UCX__CUDA_COPY=True
+DASK_DISTRIBUTED__COMM__UCX__TCP=True
+DASK_DISTRIBUTED__RMM__POOL_SIZE=$rmm_pool_size
scheduler_flags="--scheduler-file scheduler.json --protocol ucx"
worker_flags="--scheduler-file scheduler.json --enable-tcp-over-ucx --rmm-pool-size ${rmm_pool_size}"
@@ -34,17 +34,15 @@ if ! [ -z ${interface+x} ]; then
scheduler_flags+=" --interface ${interface}"
fi
if [[ $transport == *"nvlink"* ]]; then
- DASK_UCX__NVLINK=True
+ DASK_DISTRIBUTED__COMM__UCX__NVLINK=True
worker_flags+=" --enable-nvlink"
fi
if [[ $transport == *"ib"* ]]; then
- DASK_UCX__INFINIBAND=True
- # DASK_UCX__RDMACM=True # RDMACM not working right now
- DASK_UCX__NET_DEVICES=mlx5_0:1
+ DASK_DISTRIBUTED__COMM__UCX__INFINIBAND=True
+ DASK_DISTRIBUTED__COMM__UCX__RDMACM=True
- # worker_flags+=" --enable-infiniband --enable-rdmacm --net-devices=auto"
- worker_flags+=" --enable-infiniband --net-devices=auto"
+ worker_flags+=" --enable-infiniband --enable-rdmacm"
fi
# initialize scheduler
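
The DASK_DISTRIBUTED__COMM__UCX__* variables set in the script map onto the same dotted configuration keys, with each double underscore becoming one level of nesting. A quick way to check the mapping from Python, using a single illustrative variable:

import os
import dask

os.environ["DASK_DISTRIBUTED__COMM__UCX__NVLINK"] = "True"
dask.config.refresh()  # re-collect config, including DASK_* environment variables
print(dask.config.get("distributed.comm.ucx.nvlink"))  # -> True
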
diff --git a/requirements.txt b/requirements.txt
index 3ddbedb45..3121411a1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
-dask>=2.22.0,<=2021.07.1
-distributed>=2.22.0,<=2021.07.1
-pynvml>=8.0.3
+dask==2021.09.1
+distributed==2021.09.1
+pynvml>=11.0.0
numpy>=1.16.0
numba>=0.53.1