handle compound dtypes #19

Closed · wants to merge 1 commit
lindi/LindiClient/LindiDataset.py — 88 changes: 86 additions & 2 deletions
@@ -1,4 +1,5 @@
-from typing import Dict
+from typing import Dict, Any
import numpy as np
import zarr
import h5py
import remfile
@@ -10,6 +11,17 @@ def __init__(self, *, _zarr_array: zarr.Array):
self._zarr_array = _zarr_array
self._is_scalar = self._zarr_array.attrs.get("_SCALAR", False)

# Check for the _COMPOUND_DTYPE attribute, which signifies that
# this dataset has a compound dtype
compound_dtype_obj = self._zarr_array.attrs.get("_COMPOUND_DTYPE", None)
if compound_dtype_obj is not None:
# If we have a compound dtype, then create the numpy dtype
self._compound_dtype = np.dtype(
[(name, t) for name, t in compound_dtype_obj]
)
else:
self._compound_dtype = None

self._external_hdf5_clients: Dict[str, h5py.File] = {}

@property
@@ -35,6 +47,8 @@ def shape(self):

@property
def dtype(self):
if self._compound_dtype is not None:
return self._compound_dtype
return self._zarr_array.dtype

@property
@@ -74,6 +88,28 @@ def __getitem__(self, selection):
dataset = client[name]
assert isinstance(dataset, h5py.Dataset)
return dataset[selection]
if self._compound_dtype is not None:
# Compound dtype
# In this case we index into the compound dtype using the name of the field
# For example, if the dtype is [('x', 'f4'), ('y', 'f4')], then we can do
# dataset['x'][0] to get the first x value
assert self._compound_dtype.names is not None
if isinstance(selection, str):
# Find the index of this field in the compound dtype
ind = self._compound_dtype.names.index(selection)
# Get the dtype of this field
dtype = np.dtype(self._compound_dtype[ind])
# Return a new object that can be sliced further
# It's important that the return type is Any here, because otherwise we get linter problems
ret: Any = LindiDatasetCompoundFieldSelection(
dataset=self, ind=ind, dtype=dtype
)
return ret
else:
raise TypeError(
f"Compound dataset {self.name} does not support selection with {selection}"
)

# We use zarr's slicing, except in the case of a scalar dataset
if self.ndim == 0:
# make sure selection is ()
@@ -85,5 +121,53 @@ def __getitem__(self, selection):
def _get_external_hdf5_client(self, url: str) -> h5py.File:
if url not in self._external_hdf5_clients:
remf = remfile.File(url)
-self._external_hdf5_clients[url] = h5py.File(remf, 'r')
+self._external_hdf5_clients[url] = h5py.File(remf, "r")
return self._external_hdf5_clients[url]


class LindiDatasetCompoundFieldSelection:
"""
This class is returned when a compound dataset is indexed with a field name.
For example, if the dataset has dtype [('x', 'f4'), ('y', 'f4')], then we
can do dataset['x'][0] to get the first x value. Indexing with dataset['x']
returns an object of this class.
"""
def __init__(self, *, dataset: LindiDataset, ind: int, dtype: np.dtype):
self._dataset = dataset # The parent dataset
self._ind = ind # The index of the field in the compound dtype
self._dtype = dtype # The dtype of the field
if self._dataset.ndim != 1:
# For now we only support 1D datasets
raise TypeError(
f"Compound field selection only implemented for 1D datasets, not {self._dataset.ndim}D"
)
# Prepare the data in memory (this loads the entire dataset)
za = self._dataset._zarr_array
d = [za[i][self._ind] for i in range(len(za))]
self._data = np.array(d, dtype=self._dtype)

def __len__(self):
return self._dataset._zarr_array.shape[0]

def __iter__(self):
for i in range(len(self)):
yield self[i]

@property
def ndim(self):
return self._dataset._zarr_array.ndim

@property
def shape(self):
return self._dataset._zarr_array.shape

@property
def dtype(self):
return self._dtype

@property
def size(self):
return self._data.size

def __getitem__(self, selection):
return self._data[selection]
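
For orientation, a minimal end-to-end sketch of how the new field selection is meant to be used from the client side. It mirrors test_compound_dtype below; the top-level "from lindi import ..." imports are an assumption based on the package layout shown in this diff, and the printed values are illustrative.

# Hypothetical client-side usage (sketch, not part of this diff)
import h5py
import numpy as np
from lindi import LindiClient, LindiH5Store

with h5py.File("example.h5", "w") as f:
    dt = np.dtype([("x", "i4"), ("y", "f8")])
    f.create_dataset("X", data=np.array([(1, 3.14), (2, 6.28)], dtype=dt))

store = LindiH5Store.from_file("example.h5", url="example.h5")
client = LindiClient.from_reference_file_system(store.to_reference_file_system())
ds = client["X"]        # LindiDataset with a reconstructed compound dtype
xs = ds["x"]            # LindiDatasetCompoundFieldSelection for field 'x'
print(ds.dtype.names)   # ('x', 'y')
print(xs[0], xs[:])     # first x value, then all x values as a numpy array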
lindi/LindiH5Store/LindiH5Store.py — 9 changes: 8 additions & 1 deletion
@@ -209,6 +209,13 @@ def _get_zattrs_bytes(self, parent_key: str):
if isinstance(h5_item, h5py.Dataset):
if h5_item.ndim == 0:
dummy_group.attrs["_SCALAR"] = True
if h5_item.dtype.kind == "V": # compound type
compound_dtype = [
[name, str(h5_item.dtype[name])]
for name in h5_item.dtype.names
]
# For example: [['x', 'uint32'], ['y', 'uint32'], ['weight', 'float32']]
dummy_group.attrs["_COMPOUND_DTYPE"] = compound_dtype
external_array_link = self._get_external_array_link(parent_key, h5_item)
if external_array_link is not None:
dummy_group.attrs["_EXTERNAL_ARRAY_LINK"] = external_array_link
@@ -506,7 +513,7 @@ def _reformat_json(x: Union[bytes, None]) -> Union[bytes, None]:
if x is None:
return None
a = json.loads(x.decode("utf-8"))
-return json.dumps(a, cls=FloatJSONEncoder).encode("utf-8")
+return json.dumps(a, cls=FloatJSONEncoder, separators=(",", ":")).encode("utf-8")


# From https://github.com/rly/h5tojson/blob/b162ff7f61160a48f1dc0026acb09adafdb422fa/h5tojson/h5tojson.py#L121-L156
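
To make the new attribute concrete, here is a small round-trip sketch between a compound numpy dtype and the JSON-serializable _COMPOUND_DTYPE attribute (a list of [name, dtype-string] pairs). The dtype is the example from the comment above; this mirrors the attribute construction rather than calling the store itself.

# _COMPOUND_DTYPE round trip (illustrative, not part of this diff)
import numpy as np

dt = np.dtype([("x", "uint32"), ("y", "uint32"), ("weight", "float32")])
compound_dtype = [[name, str(dt[name])] for name in dt.names]
print(compound_dtype)  # [['x', 'uint32'], ['y', 'uint32'], ['weight', 'float32']]

# LindiDataset.__init__ reverses this when it sees the attribute:
assert np.dtype([(name, t) for name, t in compound_dtype]) == dt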
lindi/LindiH5Store/_zarr_info_for_h5_dataset.py — 52 changes: 50 additions & 2 deletions
@@ -80,7 +80,7 @@ def _zarr_info_for_h5_dataset(h5_dataset: h5py.Dataset) -> ZarrInfoForH5Dataset:
filters=None,
fill_value=' ',
object_codec=numcodecs.JSON(),
-inline_data=json.dumps([value, '|O', [1]]).encode('utf-8')
+inline_data=json.dumps([value, '|O', [1]], separators=(',', ':')).encode('utf-8')
)
else:
raise Exception(f'Not yet implemented (1): object scalar dataset with value {value} and dtype {dtype}')
@@ -124,7 +124,7 @@ def _zarr_info_for_h5_dataset(h5_dataset: h5py.Dataset) -> ZarrInfoForH5Dataset:
data_vec_view[i] = None
else:
raise Exception(f'Cannot handle dataset {h5_dataset.name} with dtype {dtype} and shape {shape}')
-inline_data = json.dumps(data.tolist() + ['|O', list(shape)]).encode('utf-8')
+inline_data = json.dumps(data.tolist() + ['|O', list(shape)], separators=(',', ':')).encode('utf-8')
return ZarrInfoForH5Dataset(
shape=shape,
chunks=shape, # be explicit about chunks
@@ -136,10 +136,58 @@ def _zarr_info_for_h5_dataset(h5_dataset: h5py.Dataset) -> ZarrInfoForH5Dataset:
)
elif dtype.kind in 'SU': # byte string or unicode string
raise Exception(f'Not yet implemented (2): dataset {h5_dataset.name} with dtype {dtype} and shape {shape}')
elif dtype.kind == 'V': # void (i.e. compound)
# This is an array representing the compound type
# For example: [['x', 'uint32'], ['y', 'uint32'], ['weight', 'float32']]
compound_dtype = [
[name, str(dtype[name])]
for name in dtype.names
]
if h5_dataset.ndim == 1:
# for now we only handle the case of a 1D compound dataset
data = h5_dataset[:]
# Create an array that would be for example like this
# [[3, 4, 5.3], [2, 1, 7.1], ...]
# where the first entry corresponds to x in the example above, the second to y, and the third to weight
# This is a more compact representation than [{'x': ...}]
# The _COMPOUND_DTYPE attribute will be set on the dataset in the zarr store
# which will be used to interpret the data
array_list = [
[
_json_serialize(data[name][i], type_str)
for name, type_str in compound_dtype
]
for i in range(h5_dataset.shape[0])
]
object_codec = numcodecs.JSON()
inline_data = array_list + ['|O', list(shape)]
return ZarrInfoForH5Dataset(
shape=shape,
chunks=shape, # be explicit about chunks
dtype='object',
filters=None,
fill_value=' ', # not sure what to put here
object_codec=object_codec,
inline_data=json.dumps(inline_data, separators=(',', ':')).encode('utf-8')
)
else:
raise Exception(f'More than one dimension not supported for compound dataset {h5_dataset.name} with dtype {dtype} and shape {shape}')
else:
raise Exception(f'Not yet implemented (3): dataset {h5_dataset.name} with dtype {dtype} and shape {shape}')


def _json_serialize(val: Any, type_str: str) -> Any:
if type_str.startswith('uint'):
return int(val)
elif type_str.startswith('int'):
return int(val)
elif type_str.startswith('float'):
return float(val)
else:
raise Exception(f'Unable to serialize {val} with type {type_str}')
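
The per-field conversion in _json_serialize is needed because the stdlib json encoder rejects numpy scalar types. A quick illustration (plain numpy/json behavior, not code from this PR):

# Why values must be converted before json.dumps (sketch)
import json
import numpy as np

try:
    json.dumps(np.uint32(3))
except TypeError as e:
    print(e)  # "Object of type uint32 is not JSON serializable"

print(json.dumps(int(np.uint32(3))))  # 3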


def _get_numeric_format_str(dtype: Any) -> Union[str, None]:
"""Get the format string for a numeric dtype.

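
Putting the pieces together, a hedged sketch of the inline payload this branch would produce for a small 1D compound dataset, following the [rows..., '|O', shape] convention used for object arrays elsewhere in this module. The values match the test dataset below; this mirrors the logic above rather than calling _zarr_info_for_h5_dataset.

# Expected inline_data for a 1D compound dataset (illustrative)
import json
import numpy as np

data = np.array([(1, 3.14), (2, 6.28)], dtype=[("x", "i4"), ("y", "f8")])
array_list = [[int(row["x"]), float(row["y"])] for row in data]
inline_data = json.dumps(array_list + ["|O", [2]], separators=(",", ":")).encode("utf-8")
print(inline_data)  # b'[[1,3.14],[2,6.28],"|O",[2]]'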
tests/test_core.py — 33 changes: 33 additions & 0 deletions
@@ -99,6 +99,39 @@ def test_numpy_array_of_strings():
raise ValueError("Arrays are not equal")


def test_compound_dtype():
print("Testing compound dtype")
with tempfile.TemporaryDirectory() as tmpdir:
filename = f"{tmpdir}/test.h5"
with h5py.File(filename, "w") as f:
dt = np.dtype([("x", "i4"), ("y", "f8")])
f.create_dataset("X", data=[(1, 3.14), (2, 6.28)], dtype=dt)
h5f = h5py.File(filename, "r")
store = LindiH5Store.from_file(filename, url=filename)
rfs = store.to_reference_file_system()
client = LindiClient.from_reference_file_system(rfs)
X1 = h5f["X"]
assert isinstance(X1, h5py.Dataset)
X2 = client["X"]
assert isinstance(X2, LindiDataset)
assert X1.shape == X2.shape
assert X1.dtype == X2.dtype
assert X1.size == X2.size
# assert X1.nbytes == X2.nbytes # nbytes are not going to match because the internal representation is different
assert len(X1) == len(X2)
if not _check_equal(X1['x'][:], X2['x'][:]):
print("WARNING. Arrays for x are not equal")
print(X1['x'][:])
print(X2['x'][:])
raise ValueError("Arrays are not equal")
if not _check_equal(X1['y'][:], X2['y'][:]):
print("WARNING. Arrays for y are not equal")
print(X1['y'][:])
print(X2['y'][:])
raise ValueError("Arrays are not equal")
store.close()


def test_attributes():
print("Testing attributes")
with tempfile.TemporaryDirectory() as tmpdir: