Skip to content

Commit

Permalink
Zarr Python v3 updates (#523)
Browse files Browse the repository at this point in the history
Remove workaround for zarr-developers/zarr-python#1978
  • Loading branch information
tomwhite authored Jul 29, 2024
1 parent 69e9f94 commit c718787
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 36 deletions.
57 changes: 22 additions & 35 deletions cubed/storage/backends/zarr_python_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import zarr

from cubed.types import T_DType, T_RegularChunks, T_Shape, T_Store
from cubed.utils import join_path


class ZarrV3ArrayGroup(dict):
Expand Down Expand Up @@ -40,41 +39,29 @@ def open_zarr_v3_array(
if isinstance(chunks, int):
chunks = (chunks,)

if mode in ("r", "r+"):
# TODO: remove when https://github.com/zarr-developers/zarr-python/issues/1978 is fixed
if mode == "r+":
mode = "w"
if dtype is None or not hasattr(dtype, "fields") or dtype.fields is None:
return zarr.open(store=store, mode=mode, path=path)
if dtype is None or not hasattr(dtype, "fields") or dtype.fields is None:
return zarr.open(
store=store,
mode=mode,
shape=shape,
dtype=dtype,
chunks=chunks,
path=path,
)

group = zarr.open_group(store=store, mode=mode, path=path)

# create/open all the arrays in the group
ret = ZarrV3ArrayGroup(shape=shape, dtype=dtype, chunks=chunks)
for field in dtype.fields:
field_dtype, _ = dtype.fields[field]
if mode in ("r", "r+"):
ret[field] = group[field]
else:
ret = ZarrV3ArrayGroup(shape=shape, dtype=dtype, chunks=chunks)
for field in dtype.fields:
field_dtype, _ = dtype.fields[field]
field_path = field if path is None else join_path(path, field)
ret[field] = zarr.open(store=store, mode=mode, path=field_path)
return ret
else:
overwrite = True if mode == "a" else False
if dtype is None or not hasattr(dtype, "fields") or dtype.fields is None:
return zarr.create(
ret[field] = group.create_array(
field,
shape=shape,
dtype=dtype,
dtype=field_dtype,
chunk_shape=chunks,
store=store,
overwrite=overwrite,
path=path,
)
else:
ret = ZarrV3ArrayGroup(shape=shape, dtype=dtype, chunks=chunks)
for field in dtype.fields:
field_dtype, _ = dtype.fields[field]
field_path = field if path is None else join_path(path, field)
ret[field] = zarr.create(
shape=shape,
dtype=field_dtype,
chunk_shape=chunks,
store=store,
overwrite=overwrite,
path=field_path,
)
return ret
return ret
2 changes: 1 addition & 1 deletion cubed/tests/storage/test_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def test_lazy_zarr_array(tmp_path):
arr = lazy_zarr_array(zarr_path, shape=(3, 3), dtype=int, chunks=(2, 2))

assert not zarr_path.exists()
with pytest.raises((TypeError, ValueError)):
with pytest.raises((FileNotFoundError, TypeError, ValueError)):
arr.open()

arr.create()
Expand Down

0 comments on commit c718787

Please sign in to comment.