diff --git a/cubed/storage/backends/zarr_python_v3.py b/cubed/storage/backends/zarr_python_v3.py index dd957a4a..f88bc324 100644 --- a/cubed/storage/backends/zarr_python_v3.py +++ b/cubed/storage/backends/zarr_python_v3.py @@ -3,7 +3,6 @@ import zarr from cubed.types import T_DType, T_RegularChunks, T_Shape, T_Store -from cubed.utils import join_path class ZarrV3ArrayGroup(dict): @@ -40,41 +39,29 @@ def open_zarr_v3_array( if isinstance(chunks, int): chunks = (chunks,) - if mode in ("r", "r+"): - # TODO: remove when https://github.com/zarr-developers/zarr-python/issues/1978 is fixed - if mode == "r+": - mode = "w" - if dtype is None or not hasattr(dtype, "fields") or dtype.fields is None: - return zarr.open(store=store, mode=mode, path=path) + if dtype is None or not hasattr(dtype, "fields") or dtype.fields is None: + return zarr.open( + store=store, + mode=mode, + shape=shape, + dtype=dtype, + chunks=chunks, + path=path, + ) + + group = zarr.open_group(store=store, mode=mode, path=path) + + # create/open all the arrays in the group + ret = ZarrV3ArrayGroup(shape=shape, dtype=dtype, chunks=chunks) + for field in dtype.fields: + field_dtype, _ = dtype.fields[field] + if mode in ("r", "r+"): + ret[field] = group[field] else: - ret = ZarrV3ArrayGroup(shape=shape, dtype=dtype, chunks=chunks) - for field in dtype.fields: - field_dtype, _ = dtype.fields[field] - field_path = field if path is None else join_path(path, field) - ret[field] = zarr.open(store=store, mode=mode, path=field_path) - return ret - else: - overwrite = True if mode == "a" else False - if dtype is None or not hasattr(dtype, "fields") or dtype.fields is None: - return zarr.create( + ret[field] = group.create_array( + field, shape=shape, - dtype=dtype, + dtype=field_dtype, chunk_shape=chunks, - store=store, - overwrite=overwrite, - path=path, ) - else: - ret = ZarrV3ArrayGroup(shape=shape, dtype=dtype, chunks=chunks) - for field in dtype.fields: - field_dtype, _ = dtype.fields[field] - field_path = field if path is None else join_path(path, field) - ret[field] = zarr.create( - shape=shape, - dtype=field_dtype, - chunk_shape=chunks, - store=store, - overwrite=overwrite, - path=field_path, - ) - return ret + return ret diff --git a/cubed/tests/storage/test_zarr.py b/cubed/tests/storage/test_zarr.py index fdc189b5..ce27dff8 100644 --- a/cubed/tests/storage/test_zarr.py +++ b/cubed/tests/storage/test_zarr.py @@ -8,7 +8,7 @@ def test_lazy_zarr_array(tmp_path): arr = lazy_zarr_array(zarr_path, shape=(3, 3), dtype=int, chunks=(2, 2)) assert not zarr_path.exists() - with pytest.raises((TypeError, ValueError)): + with pytest.raises((FileNotFoundError, TypeError, ValueError)): arr.open() arr.create()