Skip to content

Commit

Permalink
Add store_icechunk
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Dec 4, 2024
1 parent ec36b1a commit a4a84c7
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 0 deletions.
86 changes: 86 additions & 0 deletions cubed/icechunk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import TYPE_CHECKING, Any, List, Sequence, Union

import zarr
from icechunk import IcechunkStore

from cubed import compute
from cubed.core.array import CoreArray
from cubed.core.ops import blockwise
from cubed.runtime.types import Callback

if TYPE_CHECKING:
from cubed.array_api.array_object import Array


def store_icechunk(
store: IcechunkStore,
*,
sources: Union["Array", Sequence["Array"]],
targets: List[zarr.Array],
executor=None,
**kwargs: Any,
) -> None:
if isinstance(sources, CoreArray):
sources = [sources]
targets = [targets]

if any(not isinstance(s, CoreArray) for s in sources):
raise ValueError("All sources must be cubed array objects")

if len(sources) != len(targets):
raise ValueError(
f"Different number of sources ({len(sources)}) and targets ({len(targets)})"
)

if isinstance(sources, CoreArray):
sources = [sources]
targets = [targets]

arrays = []
for source, target in zip(sources, targets):
identity = lambda a: a
ind = tuple(range(source.ndim))
array = blockwise(
identity,
ind,
source,
ind,
dtype=source.dtype,
align_arrays=False,
target_store=target,
return_writes_stores=True,
)
arrays.append(array)

# use a callback to merge icechunk stores
store_callback = IcechunkStoreCallback()
# add to other callbacks the user may have set
callbacks = kwargs.pop("callbacks", [])
callbacks = [store_callback] + list(callbacks)

compute(
*arrays,
executor=executor,
_return_in_memory_array=False,
callbacks=callbacks,
**kwargs,
)

# merge back into the store passed into this function
merged_store = store_callback.store
store.merge(merged_store.change_set_bytes())


class IcechunkStoreCallback(Callback):
def on_compute_start(self, event):
self.store = None

def on_task_end(self, event):
result = event.result
if result is None:
return
for store in result:
if self.store is None:
self.store = store
else:
self.store.merge(store.change_set_bytes())
87 changes: 87 additions & 0 deletions cubed/tests/test_icechunk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from typing import Iterable

import icechunk
import numpy as np
import pytest
import zarr
from numpy.testing import assert_array_equal

import cubed
import cubed.array_api as xp
import cubed.random
from cubed.icechunk import store_icechunk
from cubed.tests.utils import MAIN_EXECUTORS


@pytest.fixture(
scope="module",
params=MAIN_EXECUTORS,
ids=[executor.name for executor in MAIN_EXECUTORS],
)
def executor(request):
return request.param


def create_icechunk(a, tmp_path, /, *, dtype=None, chunks=None):
# from dask.asarray
if not isinstance(getattr(a, "shape", None), Iterable):
# ensure blocks are arrays
a = np.asarray(a, dtype=dtype)
if dtype is None:
dtype = a.dtype

store = icechunk.IcechunkStore.create(
storage=icechunk.StorageConfig.filesystem(tmp_path / "icechunk"),
config=icechunk.StoreConfig(inline_chunk_threshold_bytes=1),
read_only=False,
)

group = zarr.group(store=store, overwrite=True)
arr = group.create_array("a", shape=a.shape, chunk_shape=chunks, dtype=dtype)

arr[...] = a

store.commit("commit 1")


def test_from_zarr_icechunk(tmp_path, executor):
create_icechunk(
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
tmp_path,
chunks=(2, 2),
)

store = icechunk.IcechunkStore.open_existing(
storage=icechunk.StorageConfig.filesystem(tmp_path / "icechunk"),
)

a = cubed.from_zarr(store, path="a")
assert_array_equal(
a.compute(executor=executor), np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
)


def test_store_icechunk(tmp_path, executor):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2))

store = icechunk.IcechunkStore.create(
storage=icechunk.StorageConfig.filesystem(tmp_path / "icechunk"),
config=icechunk.StoreConfig(inline_chunk_threshold_bytes=1),
read_only=False,
)
with store.preserve_read_only():
group = zarr.group(store=store, overwrite=True)
target = group.create_array(
"a", shape=a.shape, chunk_shape=a.chunksize, dtype=a.dtype
)
store_icechunk(store, sources=a, targets=target, executor=executor)
store.commit("commit 1")

# reopen store and check contents of array
store = icechunk.IcechunkStore.open_existing(
storage=icechunk.StorageConfig.filesystem(tmp_path / "icechunk"),
)
group = zarr.open_group(store=store, mode="r")
assert_array_equal(
cubed.from_array(group["a"])[:], np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
)

0 comments on commit a4a84c7

Please sign in to comment.