-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
173 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from typing import TYPE_CHECKING, Any, List, Sequence, Union | ||
|
||
import zarr | ||
from icechunk import IcechunkStore | ||
|
||
from cubed import compute | ||
from cubed.core.array import CoreArray | ||
from cubed.core.ops import blockwise | ||
from cubed.runtime.types import Callback | ||
|
||
if TYPE_CHECKING: | ||
from cubed.array_api.array_object import Array | ||
|
||
|
||
def store_icechunk(
    store: IcechunkStore,
    *,
    sources: Union["Array", Sequence["Array"]],
    targets: Union[zarr.Array, Sequence[zarr.Array]],
    executor=None,
    **kwargs: Any,
) -> None:
    """Write cubed arrays into Zarr arrays that live in an Icechunk store.

    Each source is written to the target at the same position. The stores
    returned by the individual tasks are collected by an
    ``IcechunkStoreCallback`` and their change sets are merged back into
    ``store``. This function does not commit; the caller is expected to
    commit ``store`` afterwards.

    Parameters
    ----------
    store
        The Icechunk store that receives the merged changes.
    sources
        A single cubed array, or a sequence of cubed arrays, to write.
    targets
        The Zarr array(s) to write into, one per source. A single target may
        be passed when ``sources`` is a single array.
    executor
        Executor forwarded to ``cubed.compute``.
    **kwargs
        Extra keyword arguments forwarded to ``cubed.compute``. Any
        ``callbacks`` supplied here are kept and run after the internal
        store-merging callback.

    Raises
    ------
    ValueError
        If any source is not a cubed array, or if the number of sources and
        targets differ.
    """
    # Accept a single source/target pair as a convenience.
    if isinstance(sources, CoreArray):
        sources = [sources]
        targets = [targets]

    if any(not isinstance(s, CoreArray) for s in sources):
        raise ValueError("All sources must be cubed array objects")

    if len(sources) != len(targets):
        raise ValueError(
            f"Different number of sources ({len(sources)}) and targets ({len(targets)})"
        )

    # The write is expressed as an identity blockwise op whose output goes
    # straight into the target; built once rather than per iteration.
    identity = lambda a: a  # noqa: E731

    arrays = []
    for source, target in zip(sources, targets):
        ind = tuple(range(source.ndim))
        array = blockwise(
            identity,
            ind,
            source,
            ind,
            dtype=source.dtype,
            align_arrays=False,
            target_store=target,
            return_writes_stores=True,
        )
        arrays.append(array)

    # use a callback to merge icechunk stores
    store_callback = IcechunkStoreCallback()
    # add to other callbacks the user may have set; an explicit
    # ``callbacks=None`` is treated the same as not passing any
    callbacks = kwargs.pop("callbacks", None) or []
    callbacks = [store_callback] + list(callbacks)

    compute(
        *arrays,
        executor=executor,
        _return_in_memory_array=False,
        callbacks=callbacks,
        **kwargs,
    )

    # merge back into the store passed into this function; if no task
    # reported a store there is nothing to merge
    merged_store = getattr(store_callback, "store", None)
    if merged_store is not None:
        store.merge(merged_store.change_set_bytes())
|
||
|
||
class IcechunkStoreCallback(Callback):
    """Callback that folds the stores reported by tasks into a single store.

    After a computation, ``self.store`` holds the first store seen with the
    change sets of every later store merged into it, or ``None`` if no task
    reported any store.
    """

    def on_compute_start(self, event):
        # Reset the accumulator at the start of every computation.
        self.store = None

    def on_task_end(self, event):
        # ``event.result`` is presumably the iterable of stores written by the
        # task (requested via ``return_writes_stores=True``) -- TODO confirm.
        task_stores = event.result
        if task_stores is None:
            return
        for task_store in task_stores:
            if self.store is not None:
                self.store.merge(task_store.change_set_bytes())
            else:
                # The first store seen becomes the accumulator.
                self.store = task_store
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
from typing import Iterable | ||
|
||
import icechunk | ||
import numpy as np | ||
import pytest | ||
import zarr | ||
from numpy.testing import assert_array_equal | ||
|
||
import cubed | ||
import cubed.array_api as xp | ||
import cubed.random | ||
from cubed.icechunk import store_icechunk | ||
from cubed.tests.utils import MAIN_EXECUTORS | ||
|
||
|
||
@pytest.fixture(
    scope="module",
    params=MAIN_EXECUTORS,
    ids=[ex.name for ex in MAIN_EXECUTORS],
)
def executor(request):
    """Module-scoped fixture yielding each executor in ``MAIN_EXECUTORS``."""
    selected = request.param
    return selected
|
||
|
||
def create_icechunk(a, tmp_path, /, *, dtype=None, chunks=None):
    """Create an Icechunk store under ``tmp_path`` holding ``a`` as array "a".

    ``a`` is converted to a NumPy array (using ``dtype``) when it has no
    iterable ``shape`` attribute. The data is written into a new filesystem
    store at ``tmp_path / "icechunk"`` and committed as "commit 1".
    """
    # from dask.asarray
    shape = getattr(a, "shape", None)
    if not isinstance(shape, Iterable):
        # ensure blocks are arrays
        a = np.asarray(a, dtype=dtype)
    if dtype is None:
        dtype = a.dtype

    storage = icechunk.StorageConfig.filesystem(tmp_path / "icechunk")
    config = icechunk.StoreConfig(inline_chunk_threshold_bytes=1)
    store = icechunk.IcechunkStore.create(
        storage=storage,
        config=config,
        read_only=False,
    )

    root = zarr.group(store=store, overwrite=True)
    target = root.create_array("a", shape=a.shape, chunk_shape=chunks, dtype=dtype)
    target[...] = a

    store.commit("commit 1")
|
||
|
||
def test_from_zarr_icechunk(tmp_path, executor):
    """Reading array "a" from an existing Icechunk store returns its data."""
    data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    create_icechunk(data, tmp_path, chunks=(2, 2))

    storage = icechunk.StorageConfig.filesystem(tmp_path / "icechunk")
    store = icechunk.IcechunkStore.open_existing(storage=storage)

    a = cubed.from_zarr(store, path="a")
    computed = a.compute(executor=executor)
    assert_array_equal(computed, np.array(data))
|
||
|
||
def test_store_icechunk(tmp_path, executor):
    """Writing via ``store_icechunk`` then committing persists the array."""
    data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
    a = xp.asarray(data, chunks=(2, 2))

    storage = icechunk.StorageConfig.filesystem(tmp_path / "icechunk")
    config = icechunk.StoreConfig(inline_chunk_threshold_bytes=1)
    store = icechunk.IcechunkStore.create(
        storage=storage,
        config=config,
        read_only=False,
    )
    with store.preserve_read_only():
        root = zarr.group(store=store, overwrite=True)
        target = root.create_array(
            "a", shape=a.shape, chunk_shape=a.chunksize, dtype=a.dtype
        )
        store_icechunk(store, sources=a, targets=target, executor=executor)
        store.commit("commit 1")

    # reopen store and check contents of array
    reopened = icechunk.IcechunkStore.open_existing(
        storage=icechunk.StorageConfig.filesystem(tmp_path / "icechunk"),
    )
    root = zarr.open_group(store=reopened, mode="r")
    stored = cubed.from_array(root["a"])[:]
    assert_array_equal(stored, np.array(data))