Commit
Add to_icechunk

* Add to_icechunk
* Update APIs
* Add docs
* Update docs
* typo
* more docs
* typing
* More updates
* Try again
* Add dask to docs
* small cleanup
* add dask/distributed/xarray min versions
* More API sketching
* little more
* Add tests
* Remove extra stuff.
* ignore lint
* update xarray version
* fix
* fix bad merge
* lint
* try again
* lint
* gen_cluster fixture
* fix tests
* Update test
* Add test for threaded scheduler
* Update doc page name
* fixup typing
* Delete
* optimize CI
* lint
* add comment
* fix tests
* Apply suggestions from code review

  Co-authored-by: Ryan Abernathey <[email protected]>
* Update icechunk-python/tests/test_xarray.py
* Update icechunk-python/python/icechunk/dask.py

  Co-authored-by: Ryan Abernathey <[email protected]>
* Update docs/docs/icechunk-python/dask.md

---------

Co-authored-by: Ryan Abernathey <[email protected]>
Showing 14 changed files with 985 additions and 103 deletions.
docs/docs/icechunk-python/dask.md
@@ -0,0 +1,94 @@
# Distributed Writes with Dask

You can use Icechunk in conjunction with Xarray and Dask to perform large-scale distributed writes from a multi-node cluster.
However, because of how Icechunk works, it's not possible to use the existing [`dask.array.to_zarr`](https://docs.dask.org/en/latest/generated/dask.array.to_zarr.html) or [`xarray.Dataset.to_zarr`](https://docs.xarray.dev/en/latest/generated/xarray.Dataset.to_zarr.html) functions with either the Dask multiprocessing or distributed schedulers (the multithreaded scheduler is fine). An Icechunk store records its changes in memory, so writes made in other processes produce changesets that must be explicitly merged back into the original store.

Instead, Icechunk provides its own specialized functions for distributed writes with Dask and Xarray.
This page explains how to use them.

!!! note

    Using Xarray, Dask, and Icechunk together requires `icechunk>=0.1.0a5`, `dask>=2024.11.0`, and `xarray>=2024.11.0`.
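
If you're not sure which versions are installed, you can check quickly (a minimal sketch using the standard library; it assumes the distribution names match the import names):

```python
# print installed versions to confirm they meet the minimums above
from importlib.metadata import version

for package in ("icechunk", "dask", "xarray"):
    print(package, version(package))
```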

First let's start a distributed Client and create an IcechunkStore.

```python
# initialize a distributed Client
from distributed import Client

client = Client()

# initialize the icechunk store
from icechunk import IcechunkStore, StorageConfig

storage_config = StorageConfig.filesystem("./icechunk-xarray")
icechunk_store = IcechunkStore.create(storage_config)
```

## Icechunk + Dask

Use [`icechunk.dask.store_dask`](./reference.md#icechunk.dask.store_dask) to write a Dask array to an Icechunk store.
The API follows that of [`dask.array.store`](https://docs.dask.org/en/stable/generated/dask.array.store.html) *without*
support for the `compute` kwarg.

First create a Dask array to write:

```python
import dask.array

shape = (100, 100)
dask_chunks = (20, 20)
dask_array = dask.array.random.random(shape, chunks=dask_chunks)
```

Now create the Zarr array you will write to:

```python
import zarr

zarr_chunks = (10, 10)
group = zarr.group(store=icechunk_store, overwrite=True)

zarray = group.create_array(
    "array",
    shape=shape,
    chunk_shape=zarr_chunks,
    dtype="f8",
    fill_value=float("nan"),
)
```

Note that the Zarr chunks in the store evenly divide the Dask chunks. This means each individual
write task writes to a disjoint set of Zarr chunks and will not conflict with any other task. It is your
responsibility to ensure that such conflicts are avoided.
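
If you want to verify the alignment yourself before writing, here is one quick sketch (using the `dask_chunks` and `zarr_chunks` tuples defined above; it assumes uniform chunking):

```python
# the write tasks are independent when every dask chunk
# is a whole multiple of the corresponding zarr chunk
assert all(
    dc % zc == 0 for dc, zc in zip(dask_chunks, zarr_chunks)
), "dask chunks must align with zarr chunk boundaries"
```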

Now write:

```python
import icechunk.dask

icechunk.dask.store_dask(icechunk_store, sources=[dask_array], targets=[zarray])
```

Finally, commit your changes!

```python
icechunk_store.commit("wrote a dask array!")
```
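
To sanity-check the write, you can read the data back through the same group and compare it to the computed Dask array (a minimal sketch; comparing via `numpy.testing` is just one option):

```python
import numpy as np

# indexing the zarr array returns a plain numpy array
written = group["array"][:]
np.testing.assert_array_equal(written, dask_array.compute())
```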

## Icechunk + Dask + Xarray

### Simple

[`icechunk.xarray.to_icechunk`](./reference.md#icechunk.xarray.to_icechunk) is functionally identical to Xarray's
[`Dataset.to_zarr`](https://docs.xarray.dev/en/stable/generated/xarray.Dataset.to_zarr.html), including many of the same keyword arguments.
Notably, the `compute` kwarg is not supported.

Now roundtrip an Xarray dataset:

```python
import icechunk.xarray
import xarray as xr

dataset = xr.tutorial.open_dataset("rasm", chunks={"time": 1}).isel(time=slice(24))

icechunk.xarray.to_icechunk(dataset, store=icechunk_store)

roundtripped = xr.open_zarr(icechunk_store, consolidated=False)
dataset.identical(roundtripped)
```

Finally, commit your changes!

```python
icechunk_store.commit("wrote an Xarray dataset!")
```
docs/docs/icechunk-python/reference.md
@@ -1 +1,3 @@
::: icechunk

::: icechunk.xarray
icechunk-python/python/icechunk/dask.py
@@ -0,0 +1,192 @@
```python
import itertools
from collections.abc import Callable, Iterable, Mapping, MutableMapping, Sequence
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    TypeAlias,
    overload,
)

from packaging.version import Version

import dask
import dask.array
import zarr
from dask import config
from dask.array.core import Array
from dask.base import compute_as_if_collection, tokenize
from dask.core import flatten
from dask.delayed import Delayed
from dask.highlevelgraph import HighLevelGraph
from icechunk import IcechunkStore
from icechunk.distributed import extract_store, merge_stores

SimpleGraph: TypeAlias = Mapping[tuple[str, int], tuple[Any, ...]]


def _assert_correct_dask_version() -> None:
    if Version(dask.__version__) < Version("2024.11.0"):
        raise ValueError(
            f"Writing to icechunk requires dask>=2024.11.0 but you have {dask.__version__}. Please upgrade."
        )


def store_dask(
    store: IcechunkStore,
    *,
    sources: list[Array],
    targets: list[zarr.Array],
    regions: list[tuple[slice, ...]] | None = None,
    split_every: int | None = None,
    **store_kwargs: Any,
) -> None:
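    """
    Write Dask arrays to Zarr arrays in an Icechunk store, merging the
    changeset from every remote write task back into ``store``.

    Mirrors ``dask.array.store``, but always computes eagerly: there is
    no ``compute`` kwarg.
    """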
    stored_arrays = dask.array.store(  # type: ignore[attr-defined]
        sources=sources,
        targets=targets,  # type: ignore[arg-type]
        regions=regions,
        compute=False,
        return_stored=True,
        load_stored=False,
        lock=False,
        **store_kwargs,
    )
    # Now we tree-reduce all changesets
    merged_store = stateful_store_reduce(
        stored_arrays,
        prefix="ice-changeset",
        chunk=extract_store,
        aggregate=merge_stores,
        split_every=split_every,
        compute=True,
        **store_kwargs,
    )
    store.merge(merged_store.change_set_bytes())


# tree-reduce all changesets, regardless of array
def partial_reduce(
    aggregate: Callable[..., Any],
    keys: Iterable[tuple[Any, ...]],
    *,
    layer_name: str,
    split_every: int,
) -> SimpleGraph:
    """
    Creates a new dask graph layer that aggregates `split_every` keys together.
    """
    from toolz import partition_all

    return {
        (layer_name, i): (aggregate, *keys_batch)
        for i, keys_batch in enumerate(partition_all(split_every, keys))
    }


@overload
def stateful_store_reduce(
    stored_arrays: Sequence[Array],
    *,
    chunk: Callable[..., Any],
    aggregate: Callable[..., Any],
    prefix: str | None = None,
    split_every: int | None = None,
    compute: Literal[False] = False,
    **kwargs: Any,
) -> Delayed: ...


@overload
def stateful_store_reduce(
    stored_arrays: Sequence[Array],
    *,
    chunk: Callable[..., Any],
    aggregate: Callable[..., Any],
    compute: Literal[True] = True,
    prefix: str | None = None,
    split_every: int | None = None,
    **kwargs: Any,
) -> IcechunkStore: ...


def stateful_store_reduce(
    stored_arrays: Sequence[Array],
    *,
    chunk: Callable[..., Any],
    aggregate: Callable[..., Any],
    compute: bool = True,
    prefix: str | None = None,
    split_every: int | None = None,
    **kwargs: Any,
) -> IcechunkStore | Delayed:
    _assert_correct_dask_version()

    split_every = split_every or config.get("split_every", 8)

    layers: MutableMapping[str, SimpleGraph] = {}
    dependencies: MutableMapping[str, set[str]] = {}

    array_names = tuple(a.name for a in stored_arrays)
    all_array_keys = list(
        # flatten is untyped
        itertools.chain(*[flatten(array.__dask_keys__()) for array in stored_arrays])  # type: ignore[no-untyped-call]
    )
    token = tokenize(array_names, chunk, aggregate, split_every)

    # Each write task returns one Zarr array,
    # now extract the changeset (as bytes) from each of those Zarr arrays
    map_layer_name = f"{prefix}-blockwise-{token}"
    map_dsk: SimpleGraph = {
        (map_layer_name, i): (chunk, key) for i, key in enumerate(all_array_keys)
    }
    layers[map_layer_name] = map_dsk
    dependencies[map_layer_name] = set(array_names)
    latest_layer = map_layer_name

    if aggregate is not None:
        # Now tree-reduce across *all* write tasks,
        # regardless of which Array the task belongs to
        aggprefix = f"{prefix}-merge"

        depth = 0
        keys = map_dsk.keys()
        while len(keys) > split_every:
            latest_layer = f"{aggprefix}-{depth}-{token}"

            layers[latest_layer] = partial_reduce(
                aggregate, keys, layer_name=latest_layer, split_every=split_every
            )
            previous_layer, *_ = next(iter(keys))
            dependencies[latest_layer] = {previous_layer}

            keys = layers[latest_layer].keys()
            depth += 1

        # last one
        latest_layer = f"{aggprefix}-final-{token}"
        layers[latest_layer] = partial_reduce(
            aggregate, keys, layer_name=latest_layer, split_every=split_every
        )
        previous_layer, *_ = next(iter(keys))
        dependencies[latest_layer] = {previous_layer}

    store_dsk = HighLevelGraph.merge(
        HighLevelGraph(layers, dependencies),  # type: ignore[arg-type]
        *[array.__dask_graph__() for array in stored_arrays],
    )
    if compute:
        # copied from dask.array.store
        merged_store, *_ = compute_as_if_collection(  # type: ignore[no-untyped-call]
            Array, store_dsk, list(layers[latest_layer].keys()), **kwargs
        )
        if TYPE_CHECKING:
            assert isinstance(merged_store, IcechunkStore)
        return merged_store

    else:
        key = "stateful-store-" + tokenize(array_names)
        store_dsk = HighLevelGraph.merge(
            HighLevelGraph({key: {key: (latest_layer, 0)}}, {key: {latest_layer}}),
            store_dsk,
        )
        return Delayed(key, store_dsk)  # type: ignore[no-untyped-call]
```
icechunk-python/python/icechunk/distributed.py
@@ -0,0 +1,16 @@
```python
# distributed utility functions
from typing import cast

import zarr
from icechunk import IcechunkStore


def extract_store(zarray: zarr.Array) -> IcechunkStore:
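    """Return the IcechunkStore backing a Zarr array."""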
    return cast(IcechunkStore, zarray.store)


def merge_stores(*stores: IcechunkStore) -> IcechunkStore:
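    """Fold each store's changeset into the first store and return it."""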
    store, *rest = stores
    for other in rest:
        store.merge(other.change_set_bytes())
    return store
```