Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add string and bytes dtypes plus vlen-utf8 and vlen-bytes codecs #2036

Merged
merged 35 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
c05b9d1
add legacy vlen-utf8 codec
rabernat Jul 14, 2024
c86ddc6
Merge branch 'v3' into ryan/legacy-vlen
rabernat Sep 29, 2024
a322124
got it working again
rabernat Sep 29, 2024
2a1e2e3
got strings working; broke everything else
rabernat Oct 1, 2024
1d3d7a5
change v3.metadata.data_type type
rabernat Oct 1, 2024
cd40b08
merged
rabernat Oct 1, 2024
988f9df
fixed tests
rabernat Oct 1, 2024
507161a
satisfy mypy for tests
rabernat Oct 1, 2024
1ae5e63
make strings work
rabernat Oct 3, 2024
94ecdb5
add missing module
rabernat Oct 3, 2024
2c7d638
Merge branch 'v3' into ryan/legacy-vlen
d-v-b Oct 3, 2024
b1717d8
Merge remote-tracking branch 'upstream/v3' into ryan/legacy-vlen
rabernat Oct 4, 2024
79b7d43
store -> storage
rabernat Oct 4, 2024
a5c2a37
rename module
rabernat Oct 4, 2024
717f0c7
Merge remote-tracking branch 'origin/ryan/legacy-vlen' into ryan/lega…
rabernat Oct 4, 2024
b90d8f3
merged
rabernat Oct 4, 2024
0406ea1
add vlen bytes
rabernat Oct 7, 2024
8e61a18
fix type assertions in test
rabernat Oct 7, 2024
6cf7dde
much better validation of fill value
rabernat Oct 7, 2024
28d58fa
retype parse_fill_value
rabernat Oct 7, 2024
c6de878
tests pass but not mypy
rabernat Oct 7, 2024
4f026db
attempted to change parse_fill_value typing
rabernat Oct 8, 2024
e427c7a
restore DEFAULT_DTYPE
rabernat Oct 8, 2024
7d9d897
fixup
TomAugspurger Oct 8, 2024
0c21994
docstring
TomAugspurger Oct 8, 2024
c12ac41
update test
TomAugspurger Oct 8, 2024
3aeea1e
add better DataType tests
rabernat Oct 8, 2024
cae7055
more progress on typing; still not passing mypy
rabernat Oct 8, 2024
1aeb49a
fix typing yay!
rabernat Oct 8, 2024
6714bad
make types work with numpy <, 2
rabernat Oct 8, 2024
2edf3b8
Apply suggestions from code review
rabernat Oct 8, 2024
12a0d65
Apply suggestions from code review
rabernat Oct 8, 2024
7ba7077
apply Joe's suggestions
rabernat Oct 8, 2024
1e828b4
add missing module
rabernat Oct 8, 2024
ba0f093
make _STRING_DTYPE private to try to make sphinx happy
rabernat Oct 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
much better validation of fill value
  • Loading branch information
rabernat committed Oct 7, 2024
commit 6cf7dde6214450970661a267f7409217f62e4830
17 changes: 17 additions & 0 deletions src/zarr/codecs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from __future__ import annotations

from typing import Any

import numpy as np

from zarr.codecs.blosc import BloscCname, BloscCodec, BloscShuffle
from zarr.codecs.bytes import BytesCodec, Endian
from zarr.codecs.crc32c_ import Crc32cCodec
Expand All @@ -9,6 +13,7 @@
from zarr.codecs.transpose import TransposeCodec
from zarr.codecs.vlen_utf8 import VLenBytesCodec, VLenUTF8Codec
from zarr.codecs.zstd import ZstdCodec
from zarr.core.metadata.v3 import DataType

__all__ = [
"BatchedCodecPipeline",
Expand All @@ -26,3 +31,15 @@
"VLenBytesCodec",
"ZstdCodec",
]


def get_default_array_bytes_codec(
np_dtype: np.dtype[Any],
) -> BytesCodec | VLenUTF8Codec | VLenBytesCodec:
dtype = DataType.from_numpy(np_dtype)
if dtype == DataType.string:
return VLenUTF8Codec()
elif dtype == DataType.bytes:
return VLenBytesCodec()
else:
return BytesCodec()
4 changes: 2 additions & 2 deletions src/zarr/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from zarr._compat import _deprecate_positional_args
from zarr.abc.store import Store, set_or_delete
from zarr.codecs import BytesCodec
from zarr.codecs import get_default_array_bytes_codec
from zarr.codecs._v2 import V2Compressor, V2Filters
from zarr.core.attributes import Attributes
from zarr.core.buffer import (
Expand Down Expand Up @@ -318,7 +318,7 @@ async def _create_v3(
await ensure_no_existing_node(store_path, zarr_format=3)

shape = parse_shapelike(shape)
codecs = list(codecs) if codecs is not None else [BytesCodec()]
codecs = list(codecs) if codecs is not None else [get_default_array_bytes_codec(dtype)]

if chunk_key_encoding is None:
chunk_key_encoding = ("default", "/")
Expand Down
44 changes: 42 additions & 2 deletions src/zarr/core/metadata/v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,34 @@ def parse_codecs(data: object) -> tuple[Codec, ...]:
return out


def validate_codecs(codecs: tuple[Codec, ...], dtype: DataType) -> None:
"""Check that the codecs are valid for the given dtype"""

# ensure that we have at least one ArrayBytesCodec
abcs: list[ArrayBytesCodec] = []
for codec in codecs:
if isinstance(codec, ArrayBytesCodec):
abcs.append(codec)
if len(abcs) == 0:
raise ValueError("At least one ArrayBytesCodec is required.")
elif len(abcs) > 1:
raise ValueError("Only one ArrayBytesCodec is allowed.")

abc = abcs[0]

# we need to have special codecs if we are decoding vlen strings or bytestrings
# TODO: use codec ID instead of class name
codec_id = abc.__class__.__name__
if dtype == DataType.string and not codec_id == "VLenUTF8Codec":
raise ValueError(
f"For string dtype, ArrayBytesCodec must be `VLenUTF8Codec`, got `{codec_id}`."
)
if dtype == DataType.bytes and not codec_id == "VLenBytesCodec":
raise ValueError(
f"For bytes dtype, ArrayBytesCodec must be `VLenBytesCodec`, got `{codec_id}`."
)


def parse_dimension_names(data: object) -> tuple[str | None, ...] | None:
if data is None:
return data
Expand Down Expand Up @@ -186,6 +214,8 @@ def __init__(
chunk_grid_parsed = ChunkGrid.from_dict(chunk_grid)
chunk_key_encoding_parsed = ChunkKeyEncoding.from_dict(chunk_key_encoding)
dimension_names_parsed = parse_dimension_names(dimension_names)
if fill_value is None:
fill_value = default_fill_value(data_type_parsed)
fill_value_parsed = parse_fill_value(fill_value, dtype=data_type_parsed.to_numpy())
attributes_parsed = parse_attributes(attributes)
codecs_parsed_partial = parse_codecs(codecs)
Expand All @@ -199,6 +229,7 @@ def __init__(
prototype=default_buffer_prototype(), # TODO: prototype is not needed here.
)
codecs_parsed = [c.evolve_from_array_spec(array_spec) for c in codecs_parsed_partial]
validate_codecs(codecs_parsed_partial, data_type_parsed)

object.__setattr__(self, "shape", shape_parsed)
object.__setattr__(self, "data_type", data_type_parsed)
Expand Down Expand Up @@ -360,8 +391,17 @@ def parse_fill_value(
...


def default_fill_value(dtype: DataType) -> str | bytes | np.generic:
if dtype == DataType.string:
return ""
elif dtype == DataType.bytes:
return b""
else:
return dtype.to_numpy().type(0)


def parse_fill_value(
fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool | None,
fill_value: complex | str | bytes | np.generic | Sequence[Any] | bool,
dtype: BOOL_DTYPE | INTEGER_DTYPE | FLOAT_DTYPE | COMPLEX_DTYPE | np.dtype[Any],
) -> BOOL | INTEGER | FLOAT | COMPLEX | Any:
"""
Expand All @@ -385,7 +425,7 @@ def parse_fill_value(
A scalar instance of `dtype`
"""
if fill_value is None:
return dtype.type(0)
raise ValueError("Fill value cannot be None")
if dtype.kind == "O":
return fill_value
if isinstance(fill_value, Sequence) and not isinstance(fill_value, str):
Expand Down
48 changes: 40 additions & 8 deletions tests/v3/test_codecs/test_vlen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import pytest

from zarr import Array
from zarr.abc.codec import Codec
from zarr.abc.store import Store
from zarr.codecs import VLenBytesCodec, VLenUTF8Codec
from zarr.codecs import VLenBytesCodec, VLenUTF8Codec, ZstdCodec
from zarr.core.metadata.v3 import ArrayV3Metadata, DataType
from zarr.storage.common import StorePath
from zarr.strings import NUMPY_SUPPORTS_VLEN_STRING
Expand All @@ -21,11 +22,13 @@

@pytest.mark.parametrize("store", ["memory", "local"], indirect=["store"])
@pytest.mark.parametrize("dtype", numpy_str_dtypes)
async def test_vlen_string(store: Store, dtype: None | np.dtype[Any]) -> None:
@pytest.mark.parametrize("as_object_array", [False, True])
@pytest.mark.parametrize("codecs", [None, [VLenUTF8Codec()], [VLenUTF8Codec(), ZstdCodec()]])
def test_vlen_string(
store: Store, dtype: None | np.dtype[Any], as_object_array: bool, codecs: None | list[Codec]
) -> None:
strings = ["hello", "world", "this", "is", "a", "test"]
data = np.array(strings).reshape((2, 3))
if dtype is not None:
data = data.astype(dtype)
data = np.array(strings, dtype=dtype).reshape((2, 3))

sp = StorePath(store, path="string")
a = Array.create(
Expand All @@ -34,10 +37,15 @@ async def test_vlen_string(store: Store, dtype: None | np.dtype[Any]) -> None:
chunk_shape=data.shape,
dtype=data.dtype,
fill_value="",
codecs=[VLenUTF8Codec()],
codecs=codecs,
)
assert isinstance(a.metadata, ArrayV3Metadata) # needed for mypy

# should also work if input array is an object array, provided we explicitly specified
# a stringlike dtype when creating the Array
if as_object_array:
data = data.astype("O")

a[:, :] = data
assert np.array_equal(data, a[:, :])
assert a.metadata.data_type == DataType.string
Expand All @@ -52,7 +60,9 @@ async def test_vlen_string(store: Store, dtype: None | np.dtype[Any]) -> None:


@pytest.mark.parametrize("store", ["memory", "local"], indirect=["store"])
async def test_vlen_bytes(store: Store) -> None:
@pytest.mark.parametrize("as_object_array", [False, True])
@pytest.mark.parametrize("codecs", [None, [VLenBytesCodec()], [VLenBytesCodec(), ZstdCodec()]])
def test_vlen_bytes(store: Store, as_object_array: bool, codecs: None | list[Codec]) -> None:
bstrings = [b"hello", b"world", b"this", b"is", b"a", b"test"]
data = np.array(bstrings).reshape((2, 3))
assert data.dtype == "|S5"
Expand All @@ -64,10 +74,14 @@ async def test_vlen_bytes(store: Store) -> None:
chunk_shape=data.shape,
dtype=data.dtype,
fill_value=b"",
codecs=[VLenBytesCodec()],
codecs=codecs,
)
assert isinstance(a.metadata, ArrayV3Metadata) # needed for mypy

# should also work if input array is an object array, provided we explicitly specified
# a bytesting-like dtype when creating the Array
if as_object_array:
data = data.astype("O")
a[:, :] = data
assert np.array_equal(data, a[:, :])
assert a.metadata.data_type == DataType.bytes
Expand All @@ -79,3 +93,21 @@ async def test_vlen_bytes(store: Store) -> None:
assert np.array_equal(data, b[:, :])
assert b.metadata.data_type == DataType.bytes
assert a.dtype == "O"


@pytest.mark.parametrize("store", ["memory"], indirect=["store"])
def test_vlen_errors(store: Store) -> None:
sp = StorePath(store, path="string")

# fill value must be a compatible type
with pytest.raises(ValueError, match="fill value 0 is not valid"):
Array.create(sp, shape=5, chunk_shape=5, dtype="<U4", fill_value=0)

# FIXME: this should raise but doesn't; need to fix parse_fill_value
# Problem is that parse_fill_value compares with numpy dtype('O') instead
# of DataType.bytes, and anything can be cast to Object
# with pytest.raises(ValueError, match="fill value X is not valid"):
# Array.create(sp, shape=5, chunk_shape=5, dtype='|S4', fill_value='')

a = Array.create(sp, shape=5, chunk_shape=5, dtype="<U4")
assert a.fill_value == ""
19 changes: 13 additions & 6 deletions tests/v3/test_metadata/test_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pytest

from zarr.core.metadata.v3 import (
default_fill_value,
parse_dimension_names,
parse_fill_value,
parse_zarr_format,
Expand All @@ -46,8 +47,9 @@
)

complex_dtypes = ("complex64", "complex128")
vlen_dtypes = ("string", "bytes")

dtypes = (*bool_dtypes, *int_dtypes, *float_dtypes, *complex_dtypes)
dtypes = (*bool_dtypes, *int_dtypes, *float_dtypes, *complex_dtypes, *vlen_dtypes)


@pytest.mark.parametrize("data", [None, 1, 2, 4, 5, "3"])
Expand All @@ -72,13 +74,18 @@ def parse_dimension_names_valid(data: Sequence[str] | None) -> None:


@pytest.mark.parametrize("dtype_str", dtypes)
def test_parse_auto_fill_value(dtype_str: str) -> None:
def test_default_fill_value(dtype_str: str) -> None:
"""
Test that parse_fill_value(None, dtype) results in the 0 value for the given dtype.
"""
dtype = np.dtype(dtype_str)
fill_value = None
assert parse_fill_value(fill_value, dtype) == dtype.type(0)
dtype = DataType(dtype_str)
fill_value = default_fill_value(dtype)
if dtype == DataType.string:
assert fill_value == ""
elif dtype == DataType.bytes:
assert fill_value == b""
else:
assert fill_value == dtype.to_numpy().type(0)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -337,7 +344,7 @@ async def test_special_float_fill_values(fill_value: str) -> None:
"chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
"data_type": "float64",
"chunk_key_encoding": {"name": "default", "separator": "."},
"codecs": (),
"codecs": [{"name": "bytes"}],
"fill_value": fill_value, # this is not a valid fill value for uint8
}
m = ArrayV3Metadata.from_dict(metadata_dict)
Expand Down