diff --git a/distributed/config.py b/distributed/config.py index b6e5f3d0d36..84de644c70f 100644 --- a/distributed/config.py +++ b/distributed/config.py @@ -57,6 +57,7 @@ "distributed.scheduler.events-log-length": "distributed.admin.low-level-log-length", "recent-messages-log-length": "distributed.admin.low-level-log-length", "distributed.comm.recent-messages-log-length": "distributed.admin.low-level-log-length", + "distributed.p2p.disk": "distributed.p2p.storage.disk", } # Affects yaml and env variables configs, as well as calls to dask.config.set() diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index 47e1610cde3..191dff05233 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -1045,6 +1045,14 @@ properties: description: Configuration settings for Dask communications specific to P2P properties: + buffer: + type: + - string + - integer + description: | + The maximum amount of data for P2P's comm buffers to buffer in-memory per worker. + + This limit is not absolute but used to apply back pressure. retry: type: object description: | @@ -1066,10 +1074,29 @@ properties: max: type: string description: The maximum delay between retries - disk: - type: boolean - description: | - Whether or not P2P stores intermediate data on disk instead of memory + storage: + type: object + description: Configuration settings for P2P storage + properties: + + buffer: + type: + - string + - integer + description: | + The maximum amount of data for P2P's storage buffers to buffer in-memory per worker + + This limit is not absolute but used to apply back pressure. + disk: + type: boolean + description: | + Whether or not P2P stores intermediate data on disk instead of memory + threads: + type: integer + description: Number of threads used for CPU-intensive operations per worker + io-threads: + type: integer + description: Number of threads used for I/O operations per worker dashboard: type: object diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index ec1a48aa9ed..4b8ca7f9d8c 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -301,12 +301,19 @@ distributed: p2p: comm: + buffer: 500 MiB + message-bytes-limit: 50 MiB retry: count: 10 delay: min: 1s # the first non-zero delay between re-tries max: 30s # the maximum delay between re-tries - disk: True + concurrency: 4 + storage: + buffer: 500 MiB + disk: True + concurrency: 1 + threads: 1 ################### # Bokeh dashboard # diff --git a/distributed/shuffle/_arrow.py b/distributed/shuffle/_arrow.py index 2eee7c837fc..bee73ff6f87 100644 --- a/distributed/shuffle/_arrow.py +++ b/distributed/shuffle/_arrow.py @@ -6,8 +6,6 @@ from packaging.version import parse -from dask.utils import parse_bytes - if TYPE_CHECKING: import pandas as pd import pyarrow as pa @@ -46,6 +44,27 @@ def check_minimal_arrow_version() -> None: ) +def combine_tables(tables: Iterable[pa.Table], deep_copy: bool = True) -> pa.Table: + import numpy as np + + table = concat_tables(tables) + # if table. + # if deep_copy: + # return copy_table(table) + return table.take(np.arange(start=0, stop=table.num_rows)) + # return table.combine_chunks() + + +def copy_table(table: pa.Table) -> pa.Table: + """Creates a deep-copy of the table""" + import pyarrow as pa + + # concat_arrays forced a deep-copy even if the input arrays only have a single chunk. + return pa.table( + [concat_arrays(column.chunks) for column in table.columns], schema=table.schema + ) + + def concat_tables(tables: Iterable[pa.Table]) -> pa.Table: import pyarrow as pa @@ -100,7 +119,7 @@ def serialize_table(table: pa.Table) -> bytes: stream = pa.BufferOutputStream() with pa.ipc.new_stream(stream, table.schema) as writer: writer.write_table(table) - return stream.getvalue().to_pybytes() + return stream.getvalue() def deserialize_table(buffer: bytes) -> pa.Table: @@ -110,32 +129,32 @@ def deserialize_table(buffer: bytes) -> pa.Table: return reader.read_all() +def write_to_disk(data: list[pa.Table], file: pa.OSFile) -> int: + import pyarrow as pa + + table = concat_tables(data) + del data + start = file.tell() + with pa.ipc.new_stream(file, table.schema) as writer: + writer.write_table(table) + return file.tell() - start + + def read_from_disk(path: Path) -> tuple[list[pa.Table], int]: import pyarrow as pa - batch_size = parse_bytes("1 MiB") - batch = [] shards = [] with pa.OSFile(str(path), mode="rb") as f: size = f.seek(0, whence=2) f.seek(0) - prev = 0 - offset = f.tell() - while offset < size: + while f.tell() < size: sr = pa.RecordBatchStreamReader(f) shard = sr.read_all() - offset = f.tell() - batch.append(shard) - - if offset - prev >= batch_size: - table = concat_tables(batch) - shards.append(_copy_table(table)) - batch = [] - prev = offset - if batch: - table = concat_tables(batch) - shards.append(_copy_table(table)) + shards.append(shard) + + if shards: + shards = [concat_tables(shards)] return shards, size @@ -152,10 +171,3 @@ def concat_arrays(arrays: Iterable[pa.Array]) -> pa.Array: "P2P shuffling requires pyarrow>=12.0.0 to support extension types." ) from e raise - - -def _copy_table(table: pa.Table) -> pa.Table: - import pyarrow as pa - - arrs = [concat_arrays(column.chunks) for column in table.columns] - return pa.table(data=arrs, schema=table.schema) diff --git a/distributed/shuffle/_buffer.py b/distributed/shuffle/_buffer.py index e4a44bc843e..23aa75802b4 100644 --- a/distributed/shuffle/_buffer.py +++ b/distributed/shuffle/_buffer.py @@ -5,16 +5,17 @@ import logging from collections import defaultdict from collections.abc import Iterator, Sized -from typing import TYPE_CHECKING, Any, Generic, TypeVar +from types import TracebackType +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, NamedTuple, TypeVar from distributed.metrics import time from distributed.shuffle._limiter import ResourceLimiter -from distributed.sizeof import sizeof logger = logging.getLogger("distributed.shuffle") if TYPE_CHECKING: # TODO import from collections.abc (requires Python >=3.12) - from typing_extensions import Buffer + # TODO import from typing (requires Python >= 3.10) + from typing_extensions import Buffer, TypeAlias else: Buffer = Sized @@ -23,33 +24,50 @@ T = TypeVar("T") -class ShardsBuffer(Generic[ShardType]): - """A buffer for P2P shuffle +BufferState: TypeAlias = Literal["open", "flushing", "flushed", "erred", "closed"] - The objects to buffer are typically bytes belonging to certain shards. - Typically the buffer is implemented on sending and receiving end. - The buffer allows for concurrent writing and buffers shards to reduce overhead of writing. +class SizedShard(NamedTuple): + shard: Any + size: int + + +class BaseBuffer(Generic[ShardType]): + """Base class for buffers in P2P + + The buffer allows for concurrent writes and buffers shard to reduce overhead. + Writes are non-blocking while the buffer's memory limit is not full. Once full, + buffers become blocking to apply back pressure and limit memory consumption. The shards are typically provided in a format like:: { - "bucket-0": [b"shard1", b"shard2"], - "bucket-1": [b"shard1", b"shard2"], + "bucket-0": [, , ], + "bucket-1": [, ], } - Buckets typically correspond to output partitions. - - If exceptions occur during writing, the buffer is automatically closed. Subsequent attempts to write will raise the same exception. - Flushing will not raise an exception. To ensure that the buffer finished successfully, please call `ShardsBuffer.raise_on_exception` + If exceptions occur during writing, the buffer automatically errs and stores the + exception. Subsequent attempts to write and to flush will raise this exception. + To ensure that the buffer finished successfully, please call + `BaseBuffer.raise_if_erred`. """ - shards: defaultdict[str, list[ShardType]] - sizes: defaultdict[str, int] - sizes_detail: defaultdict[str, list[int]] + #: Whether or not to drain the buffer when flushing + drain: ClassVar[bool] = True + + #: List of buffered shards per output key with their approximate in-memory size + shards: defaultdict[str, list[SizedShard]] + #: Total size of in-memory data per flushable key + flushable_sizes: defaultdict[str, int] + #: Total size of in-memory data per currently flushing key + flushing_sizes: dict[str, int] + #: Limit of concurrent tasks flushing data concurrency_limit: int + #: Plugin-wide resource limiter used to apply back-pressure memory_limiter: ResourceLimiter + #: Diagnostic data diagnostics: dict[str, float] + #: Maximum size of data per key when flushing max_message_size: int bytes_total: int @@ -57,12 +75,14 @@ class ShardsBuffer(Generic[ShardType]): bytes_written: int bytes_read: int - _accepts_input: bool - _inputs_done: bool - _exception: None | Exception + #: State of the buffer + _state: BufferState + #: Exception that occurred while flushing (if exists) + _exception: Exception | None + #: Async condition used for coordination + _flush_condition: asyncio.Condition + #: Background tasks flushing data _tasks: list[asyncio.Task] - _shards_available: asyncio.Condition - _flush_lock: asyncio.Lock def __init__( self, @@ -70,27 +90,27 @@ def __init__( concurrency_limit: int = 2, max_message_size: int = -1, ) -> None: - self._accepts_input = True - self.shards = defaultdict(list) - self.sizes = defaultdict(int) - self.sizes_detail = defaultdict(list) - self._exception = None - self.concurrency_limit = concurrency_limit - self._inputs_done = False + self.shards = defaultdict(list[SizedShard]) + self.flushable_sizes = defaultdict(int) + self.flushing_sizes = {} self.memory_limiter = memory_limiter self.diagnostics: dict[str, float] = defaultdict(float) - self._tasks = [ - asyncio.create_task(self._background_task()) - for _ in range(concurrency_limit) - ] - self._shards_available = asyncio.Condition() - self._flush_lock = asyncio.Lock() self.max_message_size = max_message_size self.bytes_total = 0 self.bytes_memory = 0 self.bytes_written = 0 self.bytes_read = 0 + self._state = "open" + self._exception = None + self._flush_condition = asyncio.Condition() + if self.memory_limiter.limit or self.drain: + self._tasks = [ + asyncio.create_task(self._background_task()) + for _ in range(concurrency_limit) + ] + else: + self._tasks = [] def heartbeat(self) -> dict[str, Any]: return { @@ -103,164 +123,175 @@ def heartbeat(self) -> dict[str, Any]: "memory_limit": self.memory_limiter.limit, } - async def process(self, id: str, shards: list[ShardType], size: int) -> None: - try: - start = time() - try: - await self._process(id, shards) - self.bytes_written += size + async def write(self, data: dict[str, list[SizedShard]]) -> None: + self.raise_if_erred() - except Exception as e: - self._exception = e - self._inputs_done = True - stop = time() - - self.diagnostics["avg_size"] = ( - 0.98 * self.diagnostics["avg_size"] + 0.02 * size + if self._state != "open": + raise RuntimeError( + f"{self} is no longer open for new data, it is {self._state}." ) - self.diagnostics["avg_duration"] = 0.98 * self.diagnostics[ - "avg_duration" - ] + 0.02 * (stop - start) - finally: - await self.memory_limiter.decrease(size) - self.bytes_memory -= size - - async def _process(self, id: str, shards: list[ShardType]) -> None: - raise NotImplementedError() - def read(self, id: str) -> ShardType: - raise NotImplementedError() + if not data: + return - @property - def empty(self) -> bool: - return not self.shards + async with self._flush_condition: + for worker, shards in data.items(): + for shard in shards: + self.shards[worker].append(shard) + size = shard.size + if worker in self.flushing_sizes: + self.flushing_sizes[worker] += size + else: + self.flushable_sizes[worker] += size + self.bytes_memory += size + self.bytes_total += size + self.memory_limiter.increase(size) + del data + self._flush_condition.notify_all() + await self.memory_limiter.wait_for_available() async def _background_task(self) -> None: def _continue() -> bool: - return bool(self.shards or self._inputs_done) + return bool(self.flushable_sizes or self._state != "open") while True: - async with self._shards_available: - await self._shards_available.wait_for(_continue) - if self._inputs_done and not self.shards: + async with self._flush_condition: + await self._flush_condition.wait_for(_continue) + if self._state != "open" and not (self.drain and self.flushable_sizes): break - part_id = max(self.sizes, key=self.sizes.__getitem__) - if self.max_message_size > 0: - size = 0 - shards = [] - while size < self.max_message_size: - try: - shard = self.shards[part_id].pop() - shards.append(shard) - s = self.sizes_detail[part_id].pop() - size += s - self.sizes[part_id] -= s - except IndexError: - break - finally: - if not self.shards[part_id]: - del self.shards[part_id] - assert not self.sizes[part_id] - del self.sizes[part_id] - assert not self.sizes_detail[part_id] - del self.sizes_detail[part_id] - else: - shards = self.shards.pop(part_id) - size = self.sizes.pop(part_id) - self._shards_available.notify_all() - await self.process(part_id, shards, size) - - async def write(self, data: dict[str, ShardType]) -> None: - """ - Writes objects into the local buffers, blocks until ready for more - - Parameters - ---------- - data: dict - A dictionary mapping destinations to the object that should - be written to that destination - - Notes - ----- - If this buffer has a memory limiter configured, then it will - apply back-pressure to the sender (blocking further receives) - if local resource usage hits the limit, until such time as the - resource usage drops. - - """ + await self.flush_largest() - if self._exception: - raise self._exception - if not self._accepts_input or self._inputs_done: - raise RuntimeError(f"Trying to put data in closed {self}.") - - if not data: + async def flush_largest(self) -> None: + if not self.flushable_sizes or self._state not in {"open", "flushing"}: return + largest_partition = max( + self.flushable_sizes, key=self.flushable_sizes.__getitem__ + ) + + partition_size = self.flushable_sizes.pop(largest_partition) + shards: list[ShardType] + + if self.max_message_size > 0: + shards = [] + message_size = 0 + while message_size < self.max_message_size: + try: + shard, size = self.shards[largest_partition].pop() + message_size += size + shards.append(shard) + except IndexError: + break + if message_size == partition_size: + assert not self.shards[largest_partition] + del self.shards[largest_partition] + else: + shards = [shard for shard, _ in self.shards.pop(largest_partition)] + message_size = partition_size - sizes = {worker: sizeof(shard) for worker, shard in data.items()} - total_batch_size = sum(sizes.values()) - self.bytes_memory += total_batch_size - self.bytes_total += total_batch_size - - self.memory_limiter.increase(total_batch_size) - async with self._shards_available: - for worker, shard in data.items(): - self.shards[worker].append(shard) - self.sizes_detail[worker].append(sizes[worker]) - self.sizes[worker] += sizes[worker] - self._shards_available.notify() - await self.memory_limiter.wait_for_available() - del data - assert total_batch_size + self.flushing_sizes[largest_partition] = partition_size - message_size - def raise_on_exception(self) -> None: - """Raises an exception if something went wrong during writing""" - if self._exception: - raise self._exception + start = time() + try: + bytes_written = await self._flush(largest_partition, shards) + if bytes_written is None: + bytes_written = message_size + + self.bytes_memory -= message_size + self.bytes_written += bytes_written + except Exception as e: + if not self._state == "erred": + self._exception = e + self._state = "erred" + self.shards.clear() + flushable_size = sum(self.flushable_sizes.values()) + self.flushable_sizes.clear() + # flushing_size = sum(self.flushing_sizes.values()) + # self.flushing_sizes.clear() + await self.memory_limiter.decrease(flushable_size) + finally: + stop = time() + self.diagnostics["avg_size"] = ( + 0.98 * self.diagnostics["avg_size"] + 0.02 * message_size + ) + self.diagnostics["avg_duration"] = 0.98 * self.diagnostics[ + "avg_duration" + ] + 0.02 * (stop - start) + + async with self._flush_condition: + size_since_flush = self.flushing_sizes.pop(largest_partition) + if self._state in {"open", "flushing"}: + if size_since_flush > 0: + self.flushable_sizes[largest_partition] = size_since_flush + else: + assert self._state == "erred" + await self.memory_limiter.decrease(size_since_flush) + await self.memory_limiter.decrease(message_size) + self._flush_condition.notify_all() async def flush(self) -> None: - """Wait until all writes are finished. - - This closes the buffer such that no new writes are allowed - """ - async with self._flush_lock: - self._accepts_input = False - async with self._shards_available: - self._shards_available.notify_all() - await self._shards_available.wait_for( - lambda: not self.shards or self._exception or self._inputs_done - ) - self._inputs_done = True - self._shards_available.notify_all() + if self._state in {"flushed", "closed"}: + return + if self._state == "flushing": + async with self._flush_condition: + await self._flush_condition.wait_for( + lambda: self._state in {"erred", "flushed", "closed"} + ) + self.raise_if_erred() + return + self.raise_if_erred() + assert self._state == "open", self._state + logger.debug(f"Flushing {self}") + self._state = "flushing" + try: + async with self._flush_condition: + self._flush_condition.notify_all() await asyncio.gather(*self._tasks) - if not self._exception: - assert not self.bytes_memory, (type(self), self.bytes_memory) + self.raise_if_erred() + assert self._state == "flushing" + self._state = "flushed" + logger.debug(f"Successfully flushed {self}") + except Exception: + logger.debug(f"Failed to flush {self}, now in {self._state}") + raise + finally: + async with self._flush_condition: + self._flush_condition.notify_all() async def close(self) -> None: - """Flush and close the buffer. - - This cleans up all allocated resources. - """ - await self.flush() - if not self._exception: - assert not self.bytes_memory, (type(self), self.bytes_memory) - for t in self._tasks: - t.cancel() - self._accepts_input = False - self._inputs_done = True + if self._state == "closed": + return + + try: + await self.flush() + except Exception: + assert self._state == "erred" + await asyncio.gather(*self._tasks, return_exceptions=True) + self._state = "closed" self.shards.clear() + self.flushable_sizes.clear() self.bytes_memory = 0 - async with self._shards_available: - self._shards_available.notify_all() - await asyncio.gather(*self._tasks) + assert not self.flushing_sizes - async def __aenter__(self) -> ShardsBuffer: + async def __aenter__(self) -> BaseBuffer: return self - async def __aexit__(self, exc: Any, typ: Any, traceback: Any) -> None: + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: await self.close() + def raise_if_erred(self) -> None: + if self._exception: + assert self._state == "erred" + raise self._exception + + async def _flush(self, id: str, shards: list[ShardType]) -> int | None: + raise NotImplementedError() + @contextlib.contextmanager def time(self, name: str) -> Iterator[None]: start = time() diff --git a/distributed/shuffle/_comms.py b/distributed/shuffle/_comms.py index 2461bb57a56..1bcd79b2d26 100644 --- a/distributed/shuffle/_comms.py +++ b/distributed/shuffle/_comms.py @@ -3,14 +3,12 @@ from collections.abc import Awaitable, Callable from typing import Any -from dask.utils import parse_bytes - -from distributed.shuffle._disk import ShardsBuffer +from distributed.shuffle._buffer import BaseBuffer from distributed.shuffle._limiter import ResourceLimiter from distributed.utils import log_errors -class CommShardsBuffer(ShardsBuffer): +class CommShardsBuffer(BaseBuffer): """Accept, buffer, and send many small messages to many workers This takes in lots of small messages destined for remote workers, buffers @@ -48,24 +46,25 @@ class CommShardsBuffer(ShardsBuffer): Number of background tasks to run. """ - max_message_size = parse_bytes("2 MiB") + drain = True def __init__( self, send: Callable[[str, list[tuple[Any, Any]]], Awaitable[None]], memory_limiter: ResourceLimiter, + message_bytes_limit: int, concurrency_limit: int = 10, ): super().__init__( memory_limiter=memory_limiter, concurrency_limit=concurrency_limit, - max_message_size=CommShardsBuffer.max_message_size, + max_message_size=message_bytes_limit, ) self.send = send @log_errors - async def _process(self, address: str, shards: list[tuple[Any, Any]]) -> None: + async def _flush(self, id: str, shards: list[Any]) -> int | None: # type: ignore[return] """Send one message off to a neighboring worker""" # Consider boosting total_size a bit here to account for duplication with self.time("send"): - await self.send(address, shards) + await self.send(id, shards) diff --git a/distributed/shuffle/_core.py b/distributed/shuffle/_core.py index 514eb0aab04..0a01afbcd06 100644 --- a/distributed/shuffle/_core.py +++ b/distributed/shuffle/_core.py @@ -20,16 +20,15 @@ import dask.config from dask.core import flatten from dask.typing import Key -from dask.utils import parse_timedelta +from dask.utils import parse_bytes, parse_timedelta from distributed.core import PooledRPCCall from distributed.exceptions import Reschedule from distributed.protocol import to_serialize from distributed.shuffle._comms import CommShardsBuffer -from distributed.shuffle._disk import DiskShardsBuffer from distributed.shuffle._exceptions import ShuffleClosedError from distributed.shuffle._limiter import ResourceLimiter -from distributed.shuffle._memory import MemoryShardsBuffer +from distributed.shuffle._storage import StorageBuffer from distributed.utils import sync from distributed.utils_comm import retry @@ -60,7 +59,7 @@ class ShuffleRun(Generic[_T_partition_id, _T_partition_type]): rpc: Callable[[str], PooledRPCCall] scheduler: PooledRPCCall closed: bool - _disk_buffer: DiskShardsBuffer | MemoryShardsBuffer + _storage_buffer: StorageBuffer _comm_buffer: CommShardsBuffer diagnostics: dict[str, float] received: set[_T_partition_id] @@ -81,6 +80,7 @@ def __init__( local_address: str, directory: str, executor: ThreadPoolExecutor, + io_executor: ThreadPoolExecutor, rpc: Callable[[str], PooledRPCCall], scheduler: PooledRPCCall, memory_limiter_disk: ResourceLimiter, @@ -95,17 +95,23 @@ def __init__( self.rpc = rpc self.scheduler = scheduler self.closed = False - if disk: - self._disk_buffer = DiskShardsBuffer( - directory=directory, - read=self.read, - memory_limiter=memory_limiter_disk, - ) - else: - self._disk_buffer = MemoryShardsBuffer(deserialize=self.deserialize) + self._storage_buffer = StorageBuffer( + directory=directory, + write=self.write, + read=self.read, + memory_limiter=memory_limiter_disk if disk else ResourceLimiter(None), + executor=io_executor, + ) + message_bytes_limit = parse_bytes( + dask.config.get("distributed.p2p.comm.message-bytes-limit") + ) + n_comm_concurrency = dask.config.get("distributed.p2p.comm.concurrency") self._comm_buffer = CommShardsBuffer( - send=self.send, memory_limiter=memory_limiter_comms + send=self.send, + memory_limiter=memory_limiter_comms, + message_bytes_limit=message_bytes_limit, + concurrency_limit=n_comm_concurrency, ) # TODO: reduce number of connections to number of workers # MultiComm.max_connections = min(10, n_workers) @@ -166,15 +172,15 @@ async def _send( async def send( self, address: str, shards: list[tuple[_T_partition_id, Any]] ) -> None: - if _mean_shard_size(shards) < 65536: - # Don't send buffers individually over the tcp comms. - # Instead, merge everything into an opaque bytes blob, send it all at once, - # and unpickle it on the other side. - # Performance tests informing the size threshold: - # https://github.com/dask/distributed/pull/8318 - shards_or_bytes: list | bytes = pickle.dumps(shards) - else: - shards_or_bytes = shards + # if _mean_shard_size(shards) < 65536: + # Don't send buffers individually over the tcp comms. + # Instead, merge everything into an opaque bytes blob, send it all at once, + # and unpickle it on the other side. + # Performance tests informing the size threshold: + # https://github.com/dask/distributed/pull/8318 + # shards_or_bytes: list | bytes = pickle.dumps(shards) + # else: + shards_or_bytes = shards return await retry( partial(self._send, address, shards_or_bytes), @@ -196,21 +202,21 @@ def heartbeat(self) -> dict[str, Any]: comm_heartbeat = self._comm_buffer.heartbeat() comm_heartbeat["read"] = self.total_recvd return { - "disk": self._disk_buffer.heartbeat(), + "disk": self._storage_buffer.heartbeat(), "comm": comm_heartbeat, "diagnostics": self.diagnostics, "start": self.start_time, } async def _write_to_comm( - self, data: dict[str, tuple[_T_partition_id, Any]] + self, data: dict[str, list[tuple[_T_partition_id, Any]]] ) -> None: self.raise_if_closed() await self._comm_buffer.write(data) async def _write_to_disk(self, data: dict[NDIndex, Any]) -> None: self.raise_if_closed() - await self._disk_buffer.write( + await self._storage_buffer.write( {"_".join(str(i) for i in k): v for k, v in data.items()} ) @@ -225,7 +231,7 @@ async def inputs_done(self) -> None: self.transferred = True await self._flush_comm() try: - self._comm_buffer.raise_on_exception() + self._comm_buffer.raise_if_erred() except Exception as e: self._exception = e raise @@ -236,7 +242,7 @@ async def _flush_comm(self) -> None: async def flush_receive(self) -> None: self.raise_if_closed() - await self._disk_buffer.flush() + await self._storage_buffer.flush() async def close(self) -> None: if self.closed: # pragma: no cover @@ -245,7 +251,7 @@ async def close(self) -> None: self.closed = True await self._comm_buffer.close() - await self._disk_buffer.close() + await self._storage_buffer.close() self._closed_event.set() def fail(self, exception: Exception) -> None: @@ -254,7 +260,7 @@ def fail(self, exception: Exception) -> None: def _read_from_disk(self, id: NDIndex) -> list[Any]: # TODO: Typing self.raise_if_closed() - return self._disk_buffer.read("_".join(str(i) for i in id)) + return self._storage_buffer.read("_".join(str(i) for i in id)) async def receive(self, data: list[tuple[_T_partition_id, Any]] | bytes) -> None: if isinstance(data, bytes): @@ -295,7 +301,7 @@ def add_partition( @abc.abstractmethod def _shard_partition( self, data: _T_partition_type, partition_id: _T_partition_id - ) -> dict[str, tuple[_T_partition_id, Any]]: + ) -> dict[str, list[tuple[_T_partition_id, Any]]]: """Shard an input partition by the assigned output workers""" def get_output_partition( @@ -315,12 +321,12 @@ def _get_output_partition( """Get an output partition to the shuffle run""" @abc.abstractmethod - def read(self, path: Path) -> tuple[Any, int]: - """Read shards from disk""" + def write(self, data: list[Any], path: Path) -> int: + """Write shards to disk""" @abc.abstractmethod - def deserialize(self, buffer: Any) -> Any: - """Deserialize shards""" + def read(self, path: Path) -> tuple[Any, int]: + """Read shards from disk""" def get_worker_plugin() -> ShuffleWorkerPlugin: diff --git a/distributed/shuffle/_merge.py b/distributed/shuffle/_merge.py index 99490bde8ec..52420abca83 100644 --- a/distributed/shuffle/_merge.py +++ b/distributed/shuffle/_merge.py @@ -107,7 +107,7 @@ def hash_join_p2p( lhs = _calculate_partitions(lhs, left_on, npartitions) rhs = _calculate_partitions(rhs, right_on, npartitions) merge_name = "hash-join-" + tokenize(lhs, rhs, **merge_kwargs) - disk: bool = dask.config.get("distributed.p2p.disk") + disk: bool = dask.config.get("distributed.p2p.storage.disk") join_layer = HashJoinP2PLayer( name=merge_name, name_input_left=lhs._name, diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index bf576f84745..809d88da150 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -106,6 +106,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, NamedTuple +from toolz import concat from tornado.ioloop import IOLoop import dask @@ -123,7 +124,7 @@ handle_unpack_errors, ) from distributed.shuffle._limiter import ResourceLimiter -from distributed.shuffle._pickle import unpickle_bytestream +from distributed.shuffle._pickle import pickle_bytelist, unpickle_bytestream from distributed.shuffle._scheduler_plugin import ShuffleSchedulerPlugin from distributed.shuffle._shuffle import barrier_key, shuffle_barrier from distributed.shuffle._worker_plugin import ShuffleWorkerPlugin @@ -178,7 +179,7 @@ def rechunk_p2p(x: da.Array, chunks: ChunkedAxes) -> da.Array: token = tokenize(x, chunks) _barrier_key = barrier_key(ShuffleId(token)) name = f"rechunk-transfer-{token}" - disk: bool = dask.config.get("distributed.p2p.disk") + disk: bool = dask.config.get("distributed.p2p.storage.disk") transfer_keys = [] for index in np.ndindex(tuple(len(dim) for dim in x.chunks)): transfer_keys.append((name,) + index) @@ -264,15 +265,14 @@ def split_axes(old: ChunkedAxes, new: ChunkedAxes) -> SplitAxes: return axes -def convert_chunk(shards: list[list[tuple[NDIndex, np.ndarray]]]) -> np.ndarray: +def convert_chunk(shards: list[tuple[NDIndex, np.ndarray]]) -> np.ndarray: import numpy as np from dask.array.core import concatenate3 indexed: dict[NDIndex, np.ndarray] = {} - for sublist in shards: - for index, shard in sublist: - indexed[index] = shard + for index, shard in shards: + indexed[index] = shard subshape = [max(dim) + 1 for dim in zip(*indexed.keys())] assert len(indexed) == np.prod(subshape) @@ -338,6 +338,7 @@ def __init__( local_address: str, directory: str, executor: ThreadPoolExecutor, + io_executor: ThreadPoolExecutor, rpc: Callable[[str], PooledRPCCall], scheduler: PooledRPCCall, memory_limiter_disk: ResourceLimiter, @@ -351,6 +352,7 @@ def __init__( local_address=local_address, directory=directory, executor=executor, + io_executor=io_executor, rpc=rpc, scheduler=scheduler, memory_limiter_comms=memory_limiter_comms, @@ -395,7 +397,7 @@ async def _receive( def _shard_partition( self, data: np.ndarray, partition_id: NDIndex - ) -> dict[str, tuple[NDIndex, Any]]: + ) -> dict[str, list[tuple[NDIndex, Any]]]: out: dict[str, list[tuple[NDIndex, tuple[NDIndex, np.ndarray]]]] = defaultdict( list ) @@ -413,7 +415,7 @@ def _shard_partition( out[self.worker_for[chunk_index]].append( (chunk_index, (shard_index, shard)) ) - return {k: (partition_id, v) for k, v in out.items()} + return {k: [(partition_id, v)] for k, v in out.items()} def _get_output_partition( self, partition_id: NDIndex, key: str, **kwargs: Any @@ -423,11 +425,15 @@ def _get_output_partition( data = self._read_from_disk(partition_id) # Copy the memory-mapped buffers from disk into memory. # This is where we'll spend most time. - with self._disk_buffer.time("read"): + with self._storage_buffer.time("read"): return convert_chunk(data) - def deserialize(self, buffer: Any) -> Any: - return buffer + def write(self, data: list[np.ndarray], path: Path) -> int: + frames = concat(pickle_bytelist(shard) for shard in data) + with path.open(mode="ab") as f: + offset = f.tell() + f.writelines(frames) + return f.tell() - offset def read(self, path: Path) -> tuple[list[list[tuple[NDIndex, np.ndarray]]], int]: """Open a memory-mapped file descriptor to disk, read all metadata, and unpickle @@ -478,6 +484,7 @@ def create_run_on_worker( f"shuffle-{self.id}-{run_id}", ), executor=plugin._executor, + io_executor=plugin._disk_executor, local_address=plugin.worker.address, rpc=plugin.worker.rpc, scheduler=plugin.worker.scheduler, diff --git a/distributed/shuffle/_shuffle.py b/distributed/shuffle/_shuffle.py index 45a0cef4666..0b7c052ca4f 100644 --- a/distributed/shuffle/_shuffle.py +++ b/distributed/shuffle/_shuffle.py @@ -25,11 +25,12 @@ check_dtype_support, check_minimal_arrow_version, convert_shards, - deserialize_table, list_of_buffers_to_table, read_from_disk, serialize_table, + write_to_disk, ) +from distributed.shuffle._buffer import SizedShard from distributed.shuffle._core import ( NDIndex, ShuffleId, @@ -124,7 +125,7 @@ def rearrange_by_column_p2p( ) name = f"shuffle_p2p-{token}" - disk: bool = dask.config.get("distributed.p2p.disk") + disk: bool = dask.config.get("distributed.p2p.storage.disk") layer = P2PShuffleLayer( name, @@ -303,6 +304,7 @@ def split_by_worker( Split data into many arrow batches, partitioned by destination worker """ import numpy as np + import pyarrow as pa from dask.dataframe.dispatch import to_pyarrow_table_dispatch @@ -332,11 +334,13 @@ def split_by_worker( splits = np.concatenate([[0], splits]) shards = [ - t.slice(offset=a, length=b - a) for a, b in toolz.sliding_window(2, splits) + t.take(pa.array(np.arange(start=a, stop=b))) + for a, b in toolz.sliding_window(2, splits) ] - shards.append(t.slice(offset=splits[-1], length=None)) - + shards.append(t.take(pa.array(np.arange(start=splits[-1], stop=t.num_rows)))) unique_codes = codes[splits] + del splits + # shards = [copy_table(shard) for shard in shards] out = { # FIXME https://github.com/pandas-dev/pandas-stubs/issues/43 worker_for.cat.categories[code]: shard @@ -351,20 +355,24 @@ def split_by_partition(t: pa.Table, column: str) -> dict[int, pa.Table]: Split data into many arrow batches, partitioned by final partition """ import numpy as np + import pyarrow as pa partitions = t.select([column]).to_pandas()[column].unique() partitions.sort() t = t.sort_by(column) - + nrows = len(t) partition = np.asarray(t[column]) splits = np.where(partition[1:] != partition[:-1])[0] + 1 splits = np.concatenate([[0], splits]) shards = [ - t.slice(offset=a, length=b - a) for a, b in toolz.sliding_window(2, splits) + t.take(pa.array(np.arange(start=a, stop=b))) + for a, b in toolz.sliding_window(2, splits) ] - shards.append(t.slice(offset=splits[-1], length=None)) - assert len(t) == sum(map(len, shards)) + shards.append(t.take(pa.array(np.arange(start=splits[-1], stop=t.num_rows)))) + del t + # shards = [copy_table(shard) for shard in shards] + assert nrows == sum(map(len, shards)) assert len(partitions) == len(shards) return dict(zip(partitions, shards)) @@ -412,6 +420,7 @@ class DataFrameShuffleRun(ShuffleRun[int, "pd.DataFrame"]): meta: pd.DataFrame partitions_of: dict[str, list[int]] worker_for: pd.Series + fds: dict[Path, pa.OSFile] def __init__( self, @@ -423,6 +432,7 @@ def __init__( local_address: str, directory: str, executor: ThreadPoolExecutor, + io_executor: ThreadPoolExecutor, rpc: Callable[[str], PooledRPCCall], scheduler: PooledRPCCall, memory_limiter_disk: ResourceLimiter, @@ -438,6 +448,7 @@ def __init__( local_address=local_address, directory=directory, executor=executor, + io_executor=io_executor, rpc=rpc, scheduler=scheduler, memory_limiter_comms=memory_limiter_comms, @@ -452,6 +463,7 @@ def __init__( partitions_of[addr].append(part) self.partitions_of = dict(partitions_of) self.worker_for = pd.Series(worker_for, name="_workers").astype("category") + self.fds = {} async def _receive(self, data: list[tuple[int, bytes]]) -> None: self.raise_if_closed() @@ -461,7 +473,8 @@ async def _receive(self, data: list[tuple[int, bytes]]) -> None: if d[0] not in self.received: filtered.append(d[1]) self.received.add(d[0]) - self.total_recvd += sizeof(d) + # FIXME + self.total_recvd += sizeof(d[1]) del data if not filtered: return @@ -473,26 +486,30 @@ async def _receive(self, data: list[tuple[int, bytes]]) -> None: self._exception = e raise - def _repartition_buffers(self, data: list[bytes]) -> dict[NDIndex, bytes]: + def _repartition_buffers(self, data: list[bytes]) -> dict[NDIndex, pa.Table]: table = list_of_buffers_to_table(data) groups = split_by_partition(table, self.column) assert len(table) == sum(map(len, groups.values())) del data - return {(k,): serialize_table(v) for k, v in groups.items()} + return { + (k,): [SizedShard(shard=v, size=v.get_total_buffer_size())] + for k, v in groups.items() + } def _shard_partition( self, data: pd.DataFrame, partition_id: int, **kwargs: Any, - ) -> dict[str, tuple[int, bytes]]: + ) -> dict[str, list[tuple[int, bytes]]]: out = split_by_worker( data, self.column, self.meta, self.worker_for, ) - out = {k: (partition_id, serialize_table(t)) for k, t in out.items()} + out = {k: serialize_table(t) for k, t in out.items()} + out = {k: [SizedShard((partition_id, b), b.size)] for k, b in out.items()} return out def _get_output_partition( @@ -510,11 +527,20 @@ def _get_output_partition( def _get_assigned_worker(self, id: int) -> str: return self.worker_for[id] + def write(self, data: list[Any], path: Path) -> int: + return write_to_disk(data, self.get_pa_file(path)) + def read(self, path: Path) -> tuple[pa.Table, int]: return read_from_disk(path) - def deserialize(self, buffer: Any) -> Any: - return deserialize_table(buffer) + def get_pa_file(self, path: Path) -> pa.OSFile: + import pyarrow as pa + + try: + return self.fds[path] + except KeyError: + self.fds[path] = pa.OSFile(str(path), mode="w") + return self.fds[path] @dataclass(frozen=True) @@ -542,6 +568,7 @@ def create_run_on_worker( f"shuffle-{self.id}-{run_id}", ), executor=plugin._executor, + io_executor=plugin._disk_executor, local_address=plugin.worker.address, rpc=plugin.worker.rpc, scheduler=plugin.worker.scheduler, diff --git a/distributed/shuffle/_disk.py b/distributed/shuffle/_storage.py similarity index 78% rename from distributed/shuffle/_disk.py rename to distributed/shuffle/_storage.py index 038a6b987ae..d17b98ba1db 100644 --- a/distributed/shuffle/_disk.py +++ b/distributed/shuffle/_storage.py @@ -1,20 +1,22 @@ from __future__ import annotations +import asyncio import contextlib import pathlib import shutil import threading -from collections.abc import Callable, Generator, Iterable +from collections.abc import Callable, Generator +from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager -from typing import Any +from typing import TYPE_CHECKING, Any -from toolz import concat - -from distributed.shuffle._buffer import ShardsBuffer +from distributed.shuffle._buffer import BaseBuffer from distributed.shuffle._limiter import ResourceLimiter -from distributed.shuffle._pickle import pickle_bytelist from distributed.utils import Deadline, log_errors +if TYPE_CHECKING: + import pyarrow as pa + class ReadWriteLock: _condition: threading.Condition @@ -92,7 +94,7 @@ def read(self) -> Generator[None, None, None]: self.release_read() -class DiskShardsBuffer(ShardsBuffer): +class StorageBuffer(BaseBuffer): """Accept, buffer, and write many small objects to many files This takes in lots of small objects, writes them to a local directory, and @@ -121,25 +123,30 @@ class DiskShardsBuffer(ShardsBuffer): implementation of this scheme. """ + drain = False + def __init__( self, directory: str | pathlib.Path, + write: Callable[[Any, pathlib.Path], int], read: Callable[[pathlib.Path], tuple[Any, int]], + executor: ThreadPoolExecutor, memory_limiter: ResourceLimiter, ): super().__init__( memory_limiter=memory_limiter, + concurrency_limit=executor._max_workers # Disk is not able to run concurrently atm - concurrency_limit=1, ) - self.directory = pathlib.Path(directory) + self.directory = pathlib.Path(directory).resolve() self.directory.mkdir(exist_ok=True) - self._closed = False + self._write_fn = write self._read = read self._directory_lock = ReadWriteLock() + self._executor = executor @log_errors - async def _process(self, id: str, shards: list[Any]) -> None: + async def _flush(self, id: str, shards: list[pa.Table]) -> int | None: """Write one buffer to file This function was built to offload the disk IO, but since then we've @@ -152,43 +159,41 @@ async def _process(self, id: str, shards: list[Any]) -> None: future then we should consider simplifying this considerably and dropping the write into communicate above. """ + # Consider boosting total_size a bit here to account for duplication with self.time("write"): # We only need shared (i.e., read) access to the directory to write # to a file inside of it. with self._directory_lock.read(): - if self._closed: - raise RuntimeError("Already closed") - - frames: Iterable[bytes | bytearray | memoryview] - - if isinstance(shards[0], bytes): - # Manually serialized dataframes - frames = shards - else: - # Unserialized numpy arrays - frames = concat(pickle_bytelist(shard) for shard in shards) - - with open(self.directory / str(id), mode="ab") as f: - f.writelines(frames) + return await asyncio.get_running_loop().run_in_executor( + self._executor, + self._write_fn, + shards, + (self.directory / str(id)), + ) def read(self, id: str) -> Any: """Read a complete file back into memory""" - self.raise_on_exception() - if not self._inputs_done: - raise RuntimeError("Tried to read from file before done.") + if self._state == "erred": + assert self._exception + raise self._exception + + if not self._state == "flushed": + raise RuntimeError(f"Tried to read from a {self._state} buffer.") try: with self.time("read"): with self._directory_lock.read(): - if self._closed: - raise RuntimeError("Already closed") - data, size = self._read((self.directory / str(id)).resolve()) + if self._state != "flushed": + raise RuntimeError("Can't read") + data, bytes_read = self._read(self.directory / str(id)) + self.bytes_read += bytes_read except FileNotFoundError: - raise KeyError(id) - + data = [] + data += [shard for shard, _ in self.shards.get(id, [])] + bytes_memory = self.flushable_sizes[id] + self.bytes_memory -= bytes_memory if data: - self.bytes_read += size return data else: raise KeyError(id) @@ -196,6 +201,5 @@ def read(self, id: str) -> Any: async def close(self) -> None: await super().close() with self._directory_lock.write(): - self._closed = True with contextlib.suppress(FileNotFoundError): shutil.rmtree(self.directory) diff --git a/distributed/shuffle/_worker_plugin.py b/distributed/shuffle/_worker_plugin.py index 9aa65003ddc..2fe75c99230 100644 --- a/distributed/shuffle/_worker_plugin.py +++ b/distributed/shuffle/_worker_plugin.py @@ -2,10 +2,12 @@ import asyncio import logging -from collections.abc import Sequence +from collections.abc import Generator, Sequence from concurrent.futures import ThreadPoolExecutor +from contextlib import ExitStack, contextmanager from typing import TYPE_CHECKING, Any, overload +import dask from dask.context import thread_state from dask.utils import parse_bytes @@ -266,8 +268,10 @@ class ShuffleWorkerPlugin(WorkerPlugin): memory_limiter_comms: ResourceLimiter memory_limiter_disk: ResourceLimiter closed: bool + _exit_stack: ExitStack def setup(self, worker: Worker) -> None: + self._exit_stack = ExitStack() # Attach to worker worker.handlers["shuffle_receive"] = self.shuffle_receive worker.handlers["shuffle_inputs_done"] = self.shuffle_inputs_done @@ -277,10 +281,37 @@ def setup(self, worker: Worker) -> None: # Initialize self.worker = worker self.shuffle_runs = _ShuffleRunManager(self) - self.memory_limiter_comms = ResourceLimiter(parse_bytes("100 MiB")) - self.memory_limiter_disk = ResourceLimiter(parse_bytes("1 GiB")) + + comm_limit = parse_bytes(dask.config.get("distributed.p2p.comm.buffer")) + self.memory_limiter_comms = ResourceLimiter(comm_limit) + + storage_limit = parse_bytes(dask.config.get("distributed.p2p.storage.buffer")) + self.memory_limiter_disk = ResourceLimiter(storage_limit) self.closed = False - self._executor = ThreadPoolExecutor(self.worker.state.nthreads) + + @contextmanager + def _executor_context( + n_threads: int, thread_name_prefix: str + ) -> Generator[ThreadPoolExecutor, None, None]: + executor = ThreadPoolExecutor( + max_workers=n_threads, thread_name_prefix=thread_name_prefix + ) + try: + yield executor + finally: + try: + executor.shutdown(cancel_futures=True) + except Exception: # pragma: no cover + executor.shutdown() + + n_threads = dask.config.get("distributed.p2p.threads") + self._executor = self._exit_stack.enter_context( + _executor_context(n_threads, thread_name_prefix="P2P-Threads") + ) + n_disk_concurrency = dask.config.get("distributed.p2p.disk.concurrency") + self._disk_executor = self._exit_stack.enter_context( + _executor_context(n_disk_concurrency, thread_name_prefix="P2P-Disk-Threads") + ) def __str__(self) -> str: return f"ShuffleWorkerPlugin on {self.worker.address}" @@ -375,10 +406,7 @@ async def teardown(self, worker: Worker) -> None: self.closed = True await self.shuffle_runs.teardown() - try: - self._executor.shutdown(cancel_futures=True) - except Exception: # pragma: no cover - self._executor.shutdown() + self._exit_stack.close() ############################# # Methods for worker thread # diff --git a/distributed/shuffle/tests/test_buffer.py b/distributed/shuffle/tests/test_buffer.py index a3ad4b07eaf..87c13b6e6d6 100644 --- a/distributed/shuffle/tests/test_buffer.py +++ b/distributed/shuffle/tests/test_buffer.py @@ -8,7 +8,7 @@ from dask.utils import parse_bytes -from distributed.shuffle._buffer import ShardsBuffer +from distributed.shuffle._buffer import BaseBuffer from distributed.shuffle._limiter import ResourceLimiter from distributed.utils import wait_for from distributed.utils_test import gen_test @@ -19,7 +19,7 @@ def gen_bytes(percentage: float, limit: int) -> bytes: return b"0" * num_bytes -class BufferTest(ShardsBuffer): +class BufferTest(BaseBuffer): def __init__(self, memory_limiter: ResourceLimiter, concurrency_limit: int) -> None: self.allow_process = asyncio.Event() self.storage: dict[str, bytes] = defaultdict(bytes) @@ -27,7 +27,7 @@ def __init__(self, memory_limiter: ResourceLimiter, concurrency_limit: int) -> N memory_limiter=memory_limiter, concurrency_limit=concurrency_limit ) - async def _process(self, id: str, shards: list[bytes]) -> None: + async def _flush(self, id: str, shards: list[bytes]) -> None: await self.allow_process.wait() self.storage[id] += b"".join(shards) @@ -41,15 +41,15 @@ def read(self, id: str) -> bytes: @pytest.mark.parametrize( "big_payloads", [ - [{"big": gen_bytes(2, limit)}], - [{"big": gen_bytes(0.5, limit)}] * 4, - [{f"big-{ix}": gen_bytes(0.5, limit)} for ix in range(4)], - [{f"big-{ix}": gen_bytes(0.5, limit)} for ix in range(2)] * 2, + [{"big": [gen_bytes(2, limit)]}], + [{"big": [gen_bytes(0.5, limit)]}] * 4, + [{f"big-{ix}": [gen_bytes(0.5, limit)]} for ix in range(4)], + [{f"big-{ix}": [gen_bytes(0.5, limit)]} for ix in range(2)] * 2, ], ) @gen_test() async def test_memory_limit(big_payloads): - small_payload = {"small": gen_bytes(0.1, limit)} + small_payload = {"small": [gen_bytes(0.1, limit)]} limiter = ResourceLimiter(limit) @@ -101,14 +101,14 @@ async def test_memory_limit(big_payloads): assert before == buf.memory_limiter.time_blocked_total -class BufferShardsBroken(ShardsBuffer): +class BufferShardsBroken(BaseBuffer): def __init__(self, memory_limiter: ResourceLimiter, concurrency_limit: int) -> None: self.storage: dict[str, bytes] = defaultdict(bytes) super().__init__( memory_limiter=memory_limiter, concurrency_limit=concurrency_limit ) - async def _process(self, id: str, shards: list[bytes]) -> None: + async def _flush(self, id: str, shards: list[bytes]) -> None: if id == "error": raise RuntimeError("Error during processing") self.storage[id] += b"".join(shards) @@ -122,10 +122,10 @@ async def test_memory_limit_blocked_exception(): limit = parse_bytes("10.0 MiB") big_payload = { - "shard-1": gen_bytes(2, limit), + "shard-1": [gen_bytes(2, limit)], } broken_payload = { - "error": "not-bytes", + "error": ["not-bytes"], } limiter = ResourceLimiter(limit) async with BufferShardsBroken( @@ -138,7 +138,8 @@ async def test_memory_limit_blocked_exception(): await big_write await small_write - await mf.flush() # Make sure exception is not dropped with pytest.raises(RuntimeError, match="Error during processing"): - mf.raise_on_exception() + await mf.flush() + with pytest.raises(RuntimeError, match="Error during processing"): + mf.raise_if_erred() diff --git a/distributed/shuffle/tests/test_comm_buffer.py b/distributed/shuffle/tests/test_comm_buffer.py index 36896c547d8..1a8a0aeef78 100644 --- a/distributed/shuffle/tests/test_comm_buffer.py +++ b/distributed/shuffle/tests/test_comm_buffer.py @@ -20,9 +20,13 @@ async def test_basic(tmp_path): async def send(address, shards): d[address].extend(shards) - mc = CommShardsBuffer(send=send, memory_limiter=ResourceLimiter(None)) - await mc.write({"x": b"0" * 1000, "y": b"1" * 500}) - await mc.write({"x": b"0" * 1000, "y": b"1" * 500}) + mc = CommShardsBuffer( + send=send, + memory_limiter=ResourceLimiter(None), + message_bytes_limit=parse_bytes("4 MiB"), + ) + await mc.write({"x": [b"0" * 1000], "y": [b"1" * 500]}) + await mc.write({"x": [b"0" * 1000], "y": [b"1" * 500]}) await mc.flush() @@ -37,16 +41,24 @@ async def test_exceptions(tmp_path): async def send(address, shards): raise Exception(123) - mc = CommShardsBuffer(send=send, memory_limiter=ResourceLimiter(None)) - await mc.write({"x": b"0" * 1000, "y": b"1" * 500}) + mc = CommShardsBuffer( + send=send, + memory_limiter=ResourceLimiter(None), + message_bytes_limit=parse_bytes("4 MiB"), + ) + await mc.write({"x": [b"0" * 1000], "y": [b"1" * 500]}) while not mc._exception: await asyncio.sleep(0.1) with pytest.raises(Exception, match="123"): - await mc.write({"x": b"0" * 1000, "y": b"1" * 500}) + await mc.write({"x": [b"0" * 1000], "y": [b"1" * 500]}) - await mc.flush() + with pytest.raises(Exception, match="123"): + await mc.flush() + + with pytest.raises(Exception, match="123"): + mc.raise_if_erred() await mc.close() @@ -64,10 +76,13 @@ async def send(address, shards): sending_first.set() mc = CommShardsBuffer( - send=send, concurrency_limit=1, memory_limiter=ResourceLimiter(None) + send=send, + concurrency_limit=1, + memory_limiter=ResourceLimiter(None), + message_bytes_limit=parse_bytes("4 MiB"), ) - await mc.write({"x": b"0", "y": b"1"}) - await mc.write({"x": b"0", "y": b"1"}) + await mc.write({"x": [b"0"], "y": [b"1"]}) + await mc.write({"x": [b"0"], "y": [b"1"]}) flush_task = asyncio.create_task(mc.flush()) await sending_first.wait() block_send.clear() @@ -95,10 +110,13 @@ async def send(address, shards): nshards = 10 nputs = 20 comm_buffer = CommShardsBuffer( - send=send, memory_limiter=ResourceLimiter(parse_bytes("100 MiB")) + send=send, + memory_limiter=ResourceLimiter(parse_bytes("1 MiB")), + message_bytes_limit=parse_bytes("4 KiB"), ) payload = { - x: gen_bytes(frac, comm_buffer.memory_limiter.limit) for x in range(nshards) + str(x): [gen_bytes(frac, comm_buffer.memory_limiter.limit)] + for x in range(nshards) } async with comm_buffer as mc: @@ -108,13 +126,15 @@ async def send(address, shards): await mc.flush() assert not mc.shards - assert not mc.sizes + assert not mc.flushable_sizes + assert not mc.flushing_sizes assert not mc.shards - assert not mc.sizes + assert not mc.flushable_sizes + assert not mc.flushing_sizes assert len(d) == 10 assert ( - sum(map(len, d[0])) + sum(map(len, d["0"])) == len(gen_bytes(frac, comm_buffer.memory_limiter.limit)) * nputs ) @@ -136,19 +156,23 @@ async def send(address, shards): nshards = 10 nputs = 20 comm_buffer = CommShardsBuffer( - send=send, memory_limiter=ResourceLimiter(parse_bytes("100 MiB")) + send=send, + memory_limiter=ResourceLimiter(parse_bytes("100 MiB")), + message_bytes_limit=parse_bytes("4 MiB"), ) payload = { - x: gen_bytes(frac, comm_buffer.memory_limiter.limit) for x in range(nshards) + str(x): [gen_bytes(frac, comm_buffer.memory_limiter.limit)] + for x in range(nshards) } async with comm_buffer as mc: futs = [asyncio.create_task(mc.write(payload)) for _ in range(nputs)] - await asyncio.gather(*futs) - await mc.flush() with pytest.raises(OSError, match="error during send"): - mc.raise_on_exception() + await mc.flush() + with pytest.raises(OSError, match="error during send"): + mc.raise_if_erred() assert not mc.shards - assert not mc.sizes + assert not mc.flushable_sizes + assert not mc.flushing_sizes diff --git a/distributed/shuffle/tests/test_disk_buffer.py b/distributed/shuffle/tests/test_disk_buffer.py deleted file mode 100644 index 40347a90b92..00000000000 --- a/distributed/shuffle/tests/test_disk_buffer.py +++ /dev/null @@ -1,142 +0,0 @@ -from __future__ import annotations - -import asyncio -import os -from pathlib import Path -from typing import Any - -import pytest - -from distributed.shuffle._disk import DiskShardsBuffer -from distributed.shuffle._limiter import ResourceLimiter -from distributed.utils_test import gen_test - - -def read_bytes(path: Path) -> tuple[bytes, int]: - with path.open("rb") as f: - data = f.read() - size = f.tell() - return data, size - - -@gen_test() -async def test_basic(tmp_path): - async with DiskShardsBuffer( - directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None) - ) as mf: - await mf.write({"x": b"0" * 1000, "y": b"1" * 500}) - await mf.write({"x": b"0" * 1000, "y": b"1" * 500}) - - await mf.flush() - - x = mf.read("x") - y = mf.read("y") - - with pytest.raises(KeyError): - mf.read("z") - - assert x == b"0" * 2000 - assert y == b"1" * 1000 - - assert not os.path.exists(tmp_path) - - -@gen_test() -async def test_read_before_flush(tmp_path): - payload = {"1": b"foo"} - async with DiskShardsBuffer( - directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None) - ) as mf: - with pytest.raises(RuntimeError): - mf.read(1) - - await mf.write(payload) - - with pytest.raises(RuntimeError): - mf.read(1) - - await mf.flush() - assert mf.read("1") == b"foo" - with pytest.raises(KeyError): - mf.read(2) - - -@pytest.mark.parametrize("count", [2, 100, 1000]) -@gen_test() -async def test_many(tmp_path, count): - async with DiskShardsBuffer( - directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None) - ) as mf: - d = {i: str(i).encode() * 100 for i in range(count)} - - for _ in range(10): - await mf.write(d) - - await mf.flush() - - for i in d: - out = mf.read(i) - assert out == str(i).encode() * 100 * 10 - - assert not os.path.exists(tmp_path) - - -class BrokenDiskShardsBuffer(DiskShardsBuffer): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - async def _process(self, *args: Any, **kwargs: Any) -> None: - raise Exception(123) - - -@gen_test() -async def test_exceptions(tmp_path): - async with BrokenDiskShardsBuffer( - directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None) - ) as mf: - await mf.write({"x": [b"0" * 1000], "y": [b"1" * 500]}) - - while not mf._exception: - await asyncio.sleep(0.1) - - with pytest.raises(Exception, match="123"): - await mf.write({"x": [b"0" * 1000], "y": [b"1" * 500]}) - - await mf.flush() - - -class EventuallyBrokenDiskShardsBuffer(DiskShardsBuffer): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.counter = 0 - - async def _process(self, *args: Any, **kwargs: Any) -> None: - # We only want to raise if this was queued up before - if self.counter > self.concurrency_limit: - raise Exception(123) - self.counter += 1 - return await super()._process(*args, **kwargs) - - -@gen_test() -async def test_high_pressure_flush_with_exception(tmp_path): - payload = {f"shard-{ix}": [f"shard-{ix}".encode() * 100] for ix in range(100)} - - async with EventuallyBrokenDiskShardsBuffer( - directory=tmp_path, read=read_bytes, memory_limiter=ResourceLimiter(None) - ) as mf: - tasks = [] - for _ in range(10): - tasks.append(asyncio.create_task(mf.write(payload))) - - # Wait until things are actually queued up. - # This is when there is no slot on the queue available anymore - # but there are still shards around - while not mf.shards: - # Disks are fast, don't give it time to unload the queue... - # There may only be a few ticks atm so keep this at zero - await asyncio.sleep(0) - - with pytest.raises(Exception, match="123"): - await mf.flush() - mf.raise_on_exception() diff --git a/distributed/shuffle/tests/test_memory_buffer.py b/distributed/shuffle/tests/test_memory_buffer.py deleted file mode 100644 index 8182597427c..00000000000 --- a/distributed/shuffle/tests/test_memory_buffer.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -import pytest - -from distributed.shuffle._memory import MemoryShardsBuffer -from distributed.utils_test import gen_test - - -def deserialize_bytes(buffer: bytes) -> bytes: - return buffer - - -@gen_test() -async def test_basic(): - async with MemoryShardsBuffer(deserialize=deserialize_bytes) as mf: - await mf.write({"x": b"0" * 1000, "y": b"1" * 500}) - await mf.write({"x": b"0" * 1000, "y": b"1" * 500}) - - await mf.flush() - - x = mf.read("x") - y = mf.read("y") - - with pytest.raises(KeyError): - mf.read("z") - - assert x == [b"0" * 1000] * 2 - assert y == [b"1" * 500] * 2 - - -@gen_test() -async def test_read_before_flush(): - payload = {"1": b"foo"} - async with MemoryShardsBuffer(deserialize=deserialize_bytes) as mf: - with pytest.raises(RuntimeError): - mf.read("1") - - await mf.write(payload) - - with pytest.raises(RuntimeError): - mf.read("1") - - await mf.flush() - assert mf.read("1") == [b"foo"] - with pytest.raises(KeyError): - mf.read("2") - - -@pytest.mark.parametrize("count", [2, 100, 1000]) -@gen_test() -async def test_many(count): - async with MemoryShardsBuffer(deserialize=deserialize_bytes) as mf: - d = {str(i): str(i).encode() * 100 for i in range(count)} - - for _ in range(10): - await mf.write(d) - - await mf.flush() - - for i in d: - out = mf.read(str(i)) - assert out == [str(i).encode() * 100] * 10 diff --git a/distributed/shuffle/tests/test_merge.py b/distributed/shuffle/tests/test_merge.py index 6c68e5e4267..9129d8afbc9 100644 --- a/distributed/shuffle/tests/test_merge.py +++ b/distributed/shuffle/tests/test_merge.py @@ -166,7 +166,7 @@ async def test_merge(c, s, a, b, how, disk): B = pd.DataFrame({"y": [1, 3, 4, 4, 5, 6], "z": [6, 5, 4, 3, 2, 1]}) b = dd.repartition(B, [0, 2, 5]) - with dask.config.set({"distributed.p2p.disk": disk}): + with dask.config.set({"distributed.p2p.storage.disk": disk}): joined = dd.merge( a, b, left_index=True, right_index=True, how=how, shuffle="p2p" ) diff --git a/distributed/shuffle/tests/test_read_write_lock.py b/distributed/shuffle/tests/test_read_write_lock.py index 4ec1ea1d891..1ef271ce9c2 100644 --- a/distributed/shuffle/tests/test_read_write_lock.py +++ b/distributed/shuffle/tests/test_read_write_lock.py @@ -4,7 +4,7 @@ import pytest -from distributed.shuffle._disk import ReadWriteLock +from distributed.shuffle._storage import ReadWriteLock def read( diff --git a/distributed/shuffle/tests/test_rechunk.py b/distributed/shuffle/tests/test_rechunk.py index d12a5474727..3241de6dd19 100644 --- a/distributed/shuffle/tests/test_rechunk.py +++ b/distributed/shuffle/tests/test_rechunk.py @@ -40,6 +40,7 @@ class ArrayRechunkTestPool(AbstractShuffleTestPool): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._executor = ThreadPoolExecutor(2) + self._io_executor = ThreadPoolExecutor(2) def __enter__(self): return self @@ -49,6 +50,10 @@ def __exit__(self, exc_type, exc_value, traceback): self._executor.shutdown(cancel_futures=True) except Exception: # pragma: no cover self._executor.shutdown() + try: + self._io_executor.shutdown(cancel_futures=True) + except Exception: # pragma: no cover + self._io_executor.shutdown() def new_shuffle( self, @@ -70,6 +75,7 @@ def new_shuffle( run_id=next(AbstractShuffleTestPool._shuffle_run_id_iterator), local_address=name, executor=self._executor, + io_executor=self._io_executor, rpc=self, scheduler=self, memory_limiter_disk=ResourceLimiter(10000000), @@ -213,7 +219,7 @@ async def test_rechunk_2d(c, s, *ws, disk): a = np.random.default_rng().uniform(0, 1, 300).reshape((10, 30)) x = da.from_array(a, chunks=((1, 2, 3, 4), (5,) * 6)) new = ((5, 5), (15,) * 2) - with dask.config.set({"distributed.p2p.disk": disk}): + with dask.config.set({"distributed.p2p.storage.disk": disk}): x2 = rechunk(x, chunks=new, method="p2p") assert x2.chunks == new assert np.all(await c.compute(x2) == a) @@ -237,7 +243,7 @@ async def test_rechunk_4d(c, s, *ws, disk): (10,), (8, 2), ) # This has been altered to return >1 output partition - with dask.config.set({"distributed.p2p.disk": disk}): + with dask.config.set({"distributed.p2p.storge.disk": disk}): x2 = rechunk(x, chunks=new, method="p2p") assert x2.chunks == new await c.compute(x2) @@ -1166,7 +1172,7 @@ async def test_preserve_writeable_flag(c, s, a, b): assert out.tolist() == [True, True] -@gen_cluster(client=True, config={"distributed.p2p.disk": False}) +@gen_cluster(client=True, config={"distributed.p2p.storage.disk": False}) async def test_rechunk_in_memory_shards_dont_share_buffer(c, s, a, b): """Test that, if two shards are sent in the same RPC call and they contribute to different output chunks, downstream tasks don't need to consume all output chunks in @@ -1193,7 +1199,9 @@ def blocked(chunk, in_map, block_map): [run] = a.extensions["shuffle"].shuffle_runs._runs shards = [ - s3 for s1 in run._disk_buffer._shards.values() for s2 in s1 for _, s3 in s2 + shard.shard[1] + for shards in run._storage_buffer.shards.values() + for shard in shards ] assert shards diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py index c04ab49f821..8beed46472c 100644 --- a/distributed/shuffle/tests/test_shuffle.py +++ b/distributed/shuffle/tests/test_shuffle.py @@ -191,7 +191,7 @@ async def test_basic_integration(c, s, a, b, npartitions, disk): dtypes={"x": float, "y": float}, freq="10 s", ) - with dask.config.set({"distributed.p2p.disk": disk}): + with dask.config.set({"distributed.p2p.storage.disk": disk}): shuffled = dd.shuffle.shuffle(df, "x", shuffle="p2p", npartitions=npartitions) if npartitions is None: assert shuffled.npartitions == df.npartitions @@ -285,6 +285,7 @@ async def test_concurrent(c, s, a, b): await check_scheduler_cleanup(s) +@pytest.mark.skip() @gen_cluster(client=True) async def test_bad_disk(c, s, a, b): df = dask.datasets.timeseries( @@ -1118,7 +1119,7 @@ def __init__(self, value: int) -> None: assert set(data) == set(worker_for.cat.categories) assert sum(map(len, data.values())) == len(df) - batches = {worker: [serialize_table(t)] for worker, t in data.items()} + batches = {worker: [t] for worker, t in data.items()} # Typically we communicate to different workers at this stage # We then receive them back and reconstute them @@ -1551,6 +1552,7 @@ class DataFrameShuffleTestPool(AbstractShuffleTestPool): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._executor = ThreadPoolExecutor(2) + self._io_executor = ThreadPoolExecutor(2) def __enter__(self): return self @@ -1560,6 +1562,10 @@ def __exit__(self, exc_type, exc_value, traceback): self._executor.shutdown(cancel_futures=True) except Exception: # pragma: no cover self._executor.shutdown() + try: + self._io_executor.shutdown(cancel_futures=True) + except Exception: # pragma: no cover + self._io_executor.shutdown() def new_shuffle( self, @@ -1580,6 +1586,7 @@ def new_shuffle( run_id=next(AbstractShuffleTestPool._shuffle_run_id_iterator), local_address=name, executor=self._executor, + io_executor=self._io_executor, rpc=self, scheduler=self, memory_limiter_disk=ResourceLimiter(10000000), diff --git a/distributed/shuffle/tests/test_storage_buffer.py b/distributed/shuffle/tests/test_storage_buffer.py new file mode 100644 index 00000000000..6a37e340257 --- /dev/null +++ b/distributed/shuffle/tests/test_storage_buffer.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +import asyncio +import os +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import Any + +import pytest + +from dask.utils import parse_bytes + +from distributed.shuffle._limiter import ResourceLimiter +from distributed.shuffle._storage import StorageBuffer +from distributed.utils_test import gen_test + + +def write_bytes(data: list[bytes], path: Path) -> int: + with path.open("ab") as f: + offset = f.tell() + f.writelines(data) + return f.tell() - offset + + +def read_bytes(path: Path) -> tuple[list[bytes], int]: + with path.open("rb") as f: + data = f.read() + size = f.tell() + return [data], size + + +# FIXME: Parametrize all tests here +@pytest.mark.parametrize("memory_limit", ["1 B"]) # , "1 KiB", "1 MiB"]) +@gen_test() +async def test_basic(tmp_path, memory_limit): + with ThreadPoolExecutor(2) as executor: + async with StorageBuffer( + directory=tmp_path, + write=write_bytes, + read=read_bytes, + executor=executor, + memory_limiter=ResourceLimiter(parse_bytes(memory_limit)), + ) as mf: + await mf.write({"x": [b"0" * 1000], "y": [b"1" * 500]}) + await mf.write({"x": [b"0" * 1000], "y": [b"1" * 500]}) + + await mf.flush() + + x = mf.read("x") + y = mf.read("y") + + with pytest.raises(KeyError): + mf.read("z") + + assert x == [b"0" * 2000] + assert y == [b"1" * 1000] + + assert not os.path.exists(tmp_path) + + +@gen_test() +async def test_read_before_flush(tmp_path): + payload = {"1": [b"foo"]} + with ThreadPoolExecutor(2) as executor: + async with StorageBuffer( + directory=tmp_path, + write=write_bytes, + read=read_bytes, + executor=executor, + memory_limiter=ResourceLimiter(None), + ) as mf: + with pytest.raises(RuntimeError): + mf.read(1) + + await mf.write(payload) + + with pytest.raises(RuntimeError): + mf.read(1) + + await mf.flush() + assert mf.read("1") == [b"foo"] + with pytest.raises(KeyError): + mf.read(2) + + +@pytest.mark.parametrize("count", [2, 100, 1000]) +@gen_test() +async def test_many(tmp_path, count): + with ThreadPoolExecutor(2) as executor: + async with StorageBuffer( + directory=tmp_path, + write=write_bytes, + read=read_bytes, + executor=executor, + memory_limiter=ResourceLimiter(1), + ) as mf: + d = {i: [str(i).encode() * 100] for i in range(count)} + + for _ in range(10): + await mf.write(d) + + await mf.flush() + + for i in d: + out = mf.read(i) + assert out == [str(i).encode() * 100 * 10] + + assert not os.path.exists(tmp_path) + + +class BrokenStorageBuffer(StorageBuffer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + async def _flush(self, *args: Any, **kwargs: Any) -> None: + raise Exception(123) + + +@gen_test() +async def test_exceptions(tmp_path): + with ThreadPoolExecutor(2) as executor: + async with BrokenStorageBuffer( + directory=tmp_path, + write=write_bytes, + read=read_bytes, + executor=executor, + memory_limiter=ResourceLimiter(1), + ) as mf: + await mf.write({"x": [b"0" * 1000], "y": [b"1" * 500]}) + + while not mf._exception: + await asyncio.sleep(0.1) + + with pytest.raises(Exception, match="123"): + await mf.write({"x": [b"0" * 1000], "y": [b"1" * 500]}) + + with pytest.raises(Exception, match="123"): + await mf.flush() + + +class EventuallyBrokenStorageBuffer(StorageBuffer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.counter = 0 + + async def _flush(self, *args: Any, **kwargs: Any) -> None: + # We only want to raise if this was queued up before + if self.counter > self.concurrency_limit: + raise Exception(123) + self.counter += 1 + await super()._flush(*args, **kwargs) + + +@pytest.mark.skip("Partial flush and this test don't work well together") +@gen_test() +async def test_high_pressure_flush_with_exception(tmp_path): + payload = {f"shard-{ix}": [f"shard-{ix}".encode() * 100] for ix in range(100)} + + with ThreadPoolExecutor(2) as executor: + async with EventuallyBrokenStorageBuffer( + directory=tmp_path, + write=write_bytes, + read=read_bytes, + executor=executor, + memory_limiter=ResourceLimiter(1), + ) as mf: + tasks = [] + for _ in range(10): + tasks.append(asyncio.create_task(mf.write(payload))) + + # Wait until things are actually queued up. + # This is when there is no slot on the queue available anymore + # but there are still shards around + while not mf.shards: + # Disks are fast, don't give it time to unload the queue... + # There may only be a few ticks atm so keep this at zero + await asyncio.sleep(0) + + with pytest.raises(Exception, match="123"): + await mf.flush() + with pytest.raises(Exception, match="123"): + mf.raise_if_erred()