Add to_icechunk for xarray (#357)
* Add to_icechunk

* Update APIs

* Add docs

* Update docs

* typo

* more docs

* typing

* More updates

* Try again

* Add dask to docs

* small cleanup

* add dask/distributed/xarray min versions

* More API sketching

* little more

* Add tests

* Remove extra stuff.

* ignore lint

* update xarray version

* fix

* fix bad merge

* lint

* try again

* lint

* gen_cluster fixture

* fix tests

* Update test

* Add test for threaded scheduler

* Update doc page name

* fixup typing

* Delete

* optimize CI

* lint

* add comment

* fix tests

* Apply suggestions from code review

Co-authored-by: Ryan Abernathey <[email protected]>

* Update icechunk-python/tests/test_xarray.py

* Update icechunk-python/python/icechunk/dask.py

Co-authored-by: Ryan Abernathey <[email protected]>

* Update docs/docs/icechunk-python/dask.md

---------

Co-authored-by: Ryan Abernathey <[email protected]>
  • Loading branch information
dcherian and rabernat authored Nov 26, 2024
1 parent f17dac6 commit 674864d
Showing 14 changed files with 985 additions and 103 deletions.
94 changes: 94 additions & 0 deletions docs/docs/icechunk-python/dask.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
# Distributed Writes with dask

You can use Icechunk in conjunction with Xarray and Dask to perform large-scale distributed writes from a multi-node cluster.
However, because Icechunk must merge the changesets produced by every write task into a single transaction, it's not possible to use the existing [`dask.array.to_zarr`](https://docs.dask.org/en/latest/generated/dask.array.to_zarr.html) or [`xarray.Dataset.to_zarr`](https://docs.xarray.dev/en/latest/generated/xarray.Dataset.to_zarr.html) functions with either the Dask multiprocessing or distributed schedulers. (The multithreaded scheduler works fine.)

Instead, Icechunk provides its own specialized functions to make distributed writes with Dask and Xarray.
This page explains how to use these specialized functions.
!!! note

    Using Xarray, Dask, and Icechunk together requires `icechunk>=0.1.0a5`, `dask>=2024.11.0`, and `xarray>=2024.11.0`.


First let's start a distributed Client and create an IcechunkStore.

```python
# initialize a distributed Client
from distributed import Client

client = Client()

# initialize the icechunk store
from icechunk import IcechunkStore, StorageConfig

storage_config = StorageConfig.filesystem("./icechunk-xarray")
icechunk_store = IcechunkStore.create(storage_config)
```

## Icechunk + Dask

Use [`icechunk.dask.store_dask`](./reference.md#icechunk.dask.store_dask) to write a Dask array to an Icechunk store.
The API follows that of [`dask.array.store`](https://docs.dask.org/en/stable/generated/dask.array.store.html) *without*
support for the `compute` kwarg.

First, create a dask array to write:
```python
import dask.array

shape = (100, 100)
dask_chunks = (20, 20)
dask_array = dask.array.random.random(shape, chunks=dask_chunks)
```

Now create the Zarr array you will write to.
```python
import zarr

zarr_chunks = (10, 10)
group = zarr.group(store=icechunk_store, overwrite=True)

zarray = group.create_array(
    "array",
    shape=shape,
    chunk_shape=zarr_chunks,
    dtype="f8",
    fill_value=float("nan"),
)
```
Note that the Zarr chunks in the store evenly divide the dask chunks, so every Zarr chunk is written by exactly one dask task. This makes each individual
write task independent of the others, so the writes will not conflict. It is your responsibility to choose chunk sizes that
avoid such conflicts.
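A quick way to sanity-check this alignment is sketched below; `chunks_are_aligned` is a hypothetical helper for illustration, not part of the Icechunk API:

```python
def chunks_are_aligned(dask_chunks, zarr_chunks):
    # Writes are safe when each Zarr chunk size evenly divides the
    # corresponding dask chunk size: every Zarr chunk then falls entirely
    # inside one dask chunk, so no two write tasks touch the same Zarr chunk.
    return all(d % z == 0 for d, z in zip(dask_chunks, zarr_chunks))

print(chunks_are_aligned((20, 20), (10, 10)))  # True: writes are conflict-free
print(chunks_are_aligned((20, 20), (15, 15)))  # False: tasks may overlap
```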

Now write:
```python
import icechunk.dask

icechunk.dask.store_dask(icechunk_store, sources=[dask_array], targets=[zarray])
```

Finally commit your changes!
```python
icechunk_store.commit("wrote a dask array!")
```

## Icechunk + Dask + Xarray

### Simple

The [`icechunk.xarray.to_icechunk`](./reference.md#icechunk.xarray.to_icechunk) function is functionally identical to Xarray's
[`Dataset.to_zarr`](https://docs.xarray.dev/en/stable/generated/xarray.Dataset.to_zarr.html) and supports many of the same keyword arguments.
Notably, the `compute` kwarg is not supported.

Now roundtrip an xarray dataset:
```python
import icechunk.xarray
import xarray as xr

dataset = xr.tutorial.open_dataset("rasm", chunks={"time": 1}).isel(time=slice(24))

icechunk.xarray.to_icechunk(dataset, store=icechunk_store)

roundtripped = xr.open_zarr(icechunk_store, consolidated=False)
dataset.identical(roundtripped)
```

Finally commit your changes!
```python
icechunk_store.commit("wrote an Xarray dataset!")
```
1 change: 1 addition & 0 deletions docs/docs/icechunk-python/index.md
@@ -4,4 +4,5 @@
- [examples](/icechunk-python/examples/)
- [notebooks](/icechunk-python/notebooks/)
- [quickstart](/icechunk-python/quickstart/)
- [distributed](/icechunk-python/dask.md)
- [reference](/icechunk-python/reference/)
2 changes: 2 additions & 0 deletions docs/docs/icechunk-python/reference.md
@@ -1 +1,3 @@
::: icechunk

::: icechunk.xarray
3 changes: 3 additions & 0 deletions docs/mkdocs.yml
@@ -137,6 +137,8 @@ plugins:
default_handler: python
handlers:
python:
options:
docstring_style: numpy
paths: [../icechunk-python/python]

- mkdocs-jupyter:
@@ -176,6 +178,7 @@ nav:
- icechunk-python/quickstart.md
- icechunk-python/configuration.md
- icechunk-python/xarray.md
- icechunk-python/dask.md
- icechunk-python/version-control.md
- Virtual Datasets: icechunk-python/virtual.md
- API Reference: icechunk-python/reference.md
4 changes: 3 additions & 1 deletion icechunk-python/examples/dask_write.py
@@ -1,5 +1,7 @@
"""
This example uses Dask to write or update an array in an Icechunk repository.
This example uses Dask as a task orchestration framework
to write or update an array in an Icechunk repository.
To write an Xarray object with dask array use `icechunk.xarray.to_icechunk`
To understand all the available options run:
```
11 changes: 9 additions & 2 deletions icechunk-python/pyproject.toml
@@ -35,8 +35,9 @@ test = [
"pytest-cov",
"pytest-asyncio",
"ruff",
"dask",
"distributed",
"dask>=2024.11.0",
"distributed>=2024.11.0",
"xarray>=2024.11.0",
"hypothesis",
]

@@ -64,6 +65,12 @@ strict = true
warn_unreachable = true
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]

[[tool.mypy.overrides]]
ignore_missing_imports = true
module = [
"toolz.*",
]

[tool.ruff]
line-length = 90
extend-exclude = [
192 changes: 192 additions & 0 deletions icechunk-python/python/icechunk/dask.py
@@ -0,0 +1,192 @@
import itertools
from collections.abc import Callable, Iterable, Mapping, MutableMapping, Sequence
from typing import (
TYPE_CHECKING,
Any,
Literal,
TypeAlias,
overload,
)

from packaging.version import Version

import dask
import dask.array
import zarr
from dask import config
from dask.array.core import Array
from dask.base import compute_as_if_collection, tokenize
from dask.core import flatten
from dask.delayed import Delayed
from dask.highlevelgraph import HighLevelGraph
from icechunk import IcechunkStore
from icechunk.distributed import extract_store, merge_stores

SimpleGraph: TypeAlias = Mapping[tuple[str, int], tuple[Any, ...]]


def _assert_correct_dask_version() -> None:
if Version(dask.__version__) < Version("2024.11.0"):
raise ValueError(
f"Writing to icechunk requires dask>=2024.11.0 but you have {dask.__version__}. Please upgrade."
)


def store_dask(
store: IcechunkStore,
*,
sources: list[Array],
targets: list[zarr.Array],
regions: list[tuple[slice, ...]] | None = None,
split_every: int | None = None,
**store_kwargs: Any,
) -> None:
stored_arrays = dask.array.store( # type: ignore[attr-defined]
sources=sources,
targets=targets, # type: ignore[arg-type]
regions=regions,
compute=False,
return_stored=True,
load_stored=False,
lock=False,
**store_kwargs,
)
# Now we tree-reduce all changesets
merged_store = stateful_store_reduce(
stored_arrays,
prefix="ice-changeset",
chunk=extract_store,
aggregate=merge_stores,
split_every=split_every,
compute=True,
**store_kwargs,
)
store.merge(merged_store.change_set_bytes())


# tree-reduce all changesets, regardless of array
def partial_reduce(
aggregate: Callable[..., Any],
keys: Iterable[tuple[Any, ...]],
*,
layer_name: str,
split_every: int,
) -> SimpleGraph:
"""
Creates a new dask graph layer, that aggregates `split_every` keys together.
"""
from toolz import partition_all

return {
(layer_name, i): (aggregate, *keys_batch)
for i, keys_batch in enumerate(partition_all(split_every, keys))
}


@overload
def stateful_store_reduce(
stored_arrays: Sequence[Array],
*,
chunk: Callable[..., Any],
aggregate: Callable[..., Any],
prefix: str | None = None,
split_every: int | None = None,
compute: Literal[False] = False,
**kwargs: Any,
) -> Delayed: ...


@overload
def stateful_store_reduce(
stored_arrays: Sequence[Array],
*,
chunk: Callable[..., Any],
aggregate: Callable[..., Any],
compute: Literal[True] = True,
prefix: str | None = None,
split_every: int | None = None,
**kwargs: Any,
) -> IcechunkStore: ...


def stateful_store_reduce(
stored_arrays: Sequence[Array],
*,
chunk: Callable[..., Any],
aggregate: Callable[..., Any],
compute: bool = True,
prefix: str | None = None,
split_every: int | None = None,
**kwargs: Any,
) -> IcechunkStore | Delayed:
_assert_correct_dask_version()

split_every = split_every or config.get("split_every", 8)

layers: MutableMapping[str, SimpleGraph] = {}
dependencies: MutableMapping[str, set[str]] = {}

array_names = tuple(a.name for a in stored_arrays)
all_array_keys = list(
# flatten is untyped
itertools.chain(*[flatten(array.__dask_keys__()) for array in stored_arrays]) # type: ignore[no-untyped-call]
)
token = tokenize(array_names, chunk, aggregate, split_every)

# Each write task returns one Zarr array,
# now extract the changeset (as bytes) from each of those Zarr arrays
map_layer_name = f"{prefix}-blockwise-{token}"
map_dsk: SimpleGraph = {
(map_layer_name, i): (chunk, key) for i, key in enumerate(all_array_keys)
}
layers[map_layer_name] = map_dsk
dependencies[map_layer_name] = set(array_names)
latest_layer = map_layer_name

if aggregate is not None:
# Now tree-reduce across *all* write tasks,
# regardless of which Array the task belongs to
aggprefix = f"{prefix}-merge"

depth = 0
keys = map_dsk.keys()
while len(keys) > split_every:
latest_layer = f"{aggprefix}-{depth}-{token}"

layers[latest_layer] = partial_reduce(
aggregate, keys, layer_name=latest_layer, split_every=split_every
)
previous_layer, *_ = next(iter(keys))
dependencies[latest_layer] = {previous_layer}

keys = layers[latest_layer].keys()
depth += 1

# last one
latest_layer = f"{aggprefix}-final-{token}"
layers[latest_layer] = partial_reduce(
aggregate, keys, layer_name=latest_layer, split_every=split_every
)
previous_layer, *_ = next(iter(keys))
dependencies[latest_layer] = {previous_layer}

store_dsk = HighLevelGraph.merge(
HighLevelGraph(layers, dependencies), # type: ignore[arg-type]
*[array.__dask_graph__() for array in stored_arrays],
)
if compute:
# copied from dask.array.store
merged_store, *_ = compute_as_if_collection( # type: ignore[no-untyped-call]
Array, store_dsk, list(layers[latest_layer].keys()), **kwargs
)
if TYPE_CHECKING:
assert isinstance(merged_store, IcechunkStore)
return merged_store

else:
key = "stateful-store-" + tokenize(array_names)
store_dsk = HighLevelGraph.merge(
HighLevelGraph({key: {key: (latest_layer, 0)}}, {key: {latest_layer}}),
store_dsk,
)
return Delayed(key, store_dsk) # type: ignore[no-untyped-call]
16 changes: 16 additions & 0 deletions icechunk-python/python/icechunk/distributed.py
@@ -0,0 +1,16 @@
# distributed utility functions
from typing import cast

import zarr
from icechunk import IcechunkStore


def extract_store(zarray: zarr.Array) -> IcechunkStore:
return cast(IcechunkStore, zarray.store)


def merge_stores(*stores: IcechunkStore) -> IcechunkStore:
store, *rest = stores
for other in rest:
store.merge(other.change_set_bytes())
return store
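The tree-reduce strategy used by `stateful_store_reduce` (merge batches of `split_every` results, then repeat on the merged results until one remains) can be sketched standalone. This is a minimal illustration, not Icechunk's actual code: `partition_all` below is a stdlib stand-in for `toolz.partition_all`, and merging changesets is modeled as summing integers.

```python
from itertools import islice


def partition_all(n, seq):
    # Stdlib stand-in for toolz.partition_all: yield lists of up to n items.
    it = iter(seq)
    while batch := list(islice(it, n)):
        yield batch


def tree_reduce(items, merge, split_every=8):
    # Repeatedly merge batches of `split_every` items until one remains,
    # mirroring the layered reduction built by stateful_store_reduce.
    while len(items) > 1:
        items = [merge(batch) for batch in partition_all(split_every, items)]
    return items[0]


# Model: "changesets" are ints, and merging is summation.
print(tree_reduce(list(range(100)), merge=sum))  # 4950
```

In the real graph each `merge` call combines Icechunk changesets on a worker, so the final merged store arrives at the client after only O(log n) reduction layers rather than a single bottleneck task.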