Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove case of LindiH5pyFile wrapping h5py.File #52

Merged
merged 6 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 20 additions & 42 deletions lindi/LindiH5pyFile/LindiH5pyAttributes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import Literal
from .LindiH5pyReference import LindiH5pyReference
from ..conversion.attr_conversion import zarr_to_h5_attr
from ..conversion.nan_inf_ninf import decode_nan_inf_ninf
Expand All @@ -14,9 +13,8 @@


class LindiH5pyAttributes:
def __init__(self, attrs, attrs_type: Literal["h5py", "zarr"], readonly: bool):
def __init__(self, attrs, readonly: bool):
self._attrs = attrs
self._attrs_type = attrs_type
self._readonly = readonly

if self._readonly:
Expand All @@ -25,43 +23,28 @@ def __init__(self, attrs, attrs_type: Literal["h5py", "zarr"], readonly: bool):
self._writer = LindiH5pyAttributesWriter(self)

def get(self, key, default=None):
if self._attrs_type == "h5py":
return self._attrs.get(key, default)
elif self._attrs_type == "zarr":
try:
if key in _special_attribute_keys:
raise KeyError
return self[key]
except KeyError:
return default
else:
raise ValueError(f"Unknown attrs_type: {self._attrs_type}")
try:
if key in _special_attribute_keys:
raise KeyError
return self[key]
except KeyError:
return default

def __contains__(self, key):
if self._attrs_type == "h5py":
return key in self._attrs
elif self._attrs_type == "zarr":
if key in _special_attribute_keys:
return False
return key in self._attrs
else:
raise ValueError(f"Unknown attrs_type: {self._attrs_type}")
if key in _special_attribute_keys:
return False
return key in self._attrs

def __getitem__(self, key):
val = self._attrs[key]
if self._attrs_type == "h5py":
return val
elif self._attrs_type == "zarr":
if isinstance(val, dict) and "_REFERENCE" in val:
return LindiH5pyReference(val["_REFERENCE"])
if isinstance(val, dict) and "_REFERENCE" in val:
return LindiH5pyReference(val["_REFERENCE"])

# Convert special float values to actual floats (NaN, Inf, -Inf)
# Note that string versions of these values are not supported
val = decode_nan_inf_ninf(val)
# Convert special float values to actual floats (NaN, Inf, -Inf)
# Note that string versions of these values are not supported
val = decode_nan_inf_ninf(val)

return zarr_to_h5_attr(val)
else:
raise ValueError(f"Unknown attrs_type: {self._attrs_type}")
return zarr_to_h5_attr(val)

def __setitem__(self, key, value):
if self._readonly:
Expand All @@ -73,15 +56,10 @@ def __delitem__(self, key):
raise KeyError("Cannot delete attributes on read-only object")

def __iter__(self):
if self._attrs_type == "h5py":
return self._attrs.__iter__()
elif self._attrs_type == "zarr":
# Do not return special zarr attributes during iteration
for k in self._attrs:
if k not in _special_attribute_keys:
yield k
else:
raise ValueError(f"Unknown attrs_type: {self._attrs_type}")
# Do not return special zarr attributes during iteration
for k in self._attrs:
if k not in _special_attribute_keys:
yield k

def items(self):
for k in self:
Expand Down
110 changes: 41 additions & 69 deletions lindi/LindiH5pyFile/LindiH5pyDataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Union, Any, Dict
from typing import TYPE_CHECKING, Any, Dict
import numpy as np
import h5py
import zarr
Expand All @@ -22,45 +22,39 @@


class LindiH5pyDataset(h5py.Dataset):
def __init__(self, _dataset_object: Union[h5py.Dataset, zarr.Array], _file: "LindiH5pyFile"):
self._dataset_object = _dataset_object
def __init__(self, _zarr_array: zarr.Array, _file: "LindiH5pyFile"):
self._zarr_array = _zarr_array
self._file = _file
self._readonly = _file.mode not in ['r+']

# see comment in LindiH5pyGroup
self._id = f'{id(self._file)}/{self._dataset_object.name}'
self._id = f'{id(self._file)}/{self._zarr_array.name}'

# See if we have the _COMPOUND_DTYPE attribute, which signifies that
# this is a compound dtype
if isinstance(_dataset_object, zarr.Array):
compound_dtype_obj = _dataset_object.attrs.get("_COMPOUND_DTYPE", None)
if compound_dtype_obj is not None:
assert isinstance(compound_dtype_obj, list)
# compound_dtype_obj is a list of tuples (name, dtype)
# where dtype == "<REFERENCE>" if it represents an HDF5 reference
for i in range(len(compound_dtype_obj)):
if compound_dtype_obj[i][1] == '<REFERENCE>':
compound_dtype_obj[i][1] = h5py.special_dtype(ref=h5py.Reference)
# If we have a compound dtype, then create the numpy dtype
self._compound_dtype = np.dtype(
[
(
compound_dtype_obj[i][0],
compound_dtype_obj[i][1]
)
for i in range(len(compound_dtype_obj))
]
)
else:
self._compound_dtype = None
compound_dtype_obj = _zarr_array.attrs.get("_COMPOUND_DTYPE", None)
if compound_dtype_obj is not None:
assert isinstance(compound_dtype_obj, list)
# compound_dtype_obj is a list of tuples (name, dtype)
# where dtype == "<REFERENCE>" if it represents an HDF5 reference
for i in range(len(compound_dtype_obj)):
if compound_dtype_obj[i][1] == '<REFERENCE>':
compound_dtype_obj[i][1] = h5py.special_dtype(ref=h5py.Reference)
# If we have a compound dtype, then create the numpy dtype
self._compound_dtype = np.dtype(
[
(
compound_dtype_obj[i][0],
compound_dtype_obj[i][1]
)
for i in range(len(compound_dtype_obj))
]
)
else:
self._compound_dtype = None

# Check whether this is a scalar dataset
if isinstance(_dataset_object, zarr.Array):
self._is_scalar = self._dataset_object.attrs.get("_SCALAR", False)
else:
self._is_scalar = self._dataset_object.ndim == 0
self._is_scalar = self._zarr_array.attrs.get("_SCALAR", False)

# The self._write object handles all the writing operations
from .writers.LindiH5pyDatasetWriter import LindiH5pyDatasetWriter # avoid circular import
Expand All @@ -79,19 +73,19 @@ def id(self):
def shape(self): # type: ignore
if self._is_scalar:
return ()
return self._dataset_object.shape
return self._zarr_array.shape

@property
def size(self):
if self._is_scalar:
return 1
return self._dataset_object.size
return self._zarr_array.size

@property
def dtype(self):
if self._compound_dtype is not None:
return self._compound_dtype
ret = self._dataset_object.dtype
ret = self._zarr_array.dtype
if ret.kind == 'O':
if not ret.metadata:
# The following correction is needed because of
Expand Down Expand Up @@ -127,15 +121,15 @@ def dtype(self):

@property
def nbytes(self):
return self._dataset_object.nbytes
return self._zarr_array.nbytes

@property
def file(self):
return self._file

@property
def name(self):
return self._dataset_object.name
return self._zarr_array.name

@property
def maxshape(self):
Expand All @@ -147,38 +141,22 @@ def maxshape(self):
def ndim(self):
if self._is_scalar:
return 0
return self._dataset_object.ndim
return self._zarr_array.ndim

@property
def attrs(self): # type: ignore
if isinstance(self._dataset_object, h5py.Dataset):
attrs_type = 'h5py'
elif isinstance(self._dataset_object, zarr.Array):
attrs_type = 'zarr'
else:
raise Exception(f'Unexpected dataset object type: {type(self._dataset_object)}')
return LindiH5pyAttributes(self._dataset_object.attrs, attrs_type=attrs_type, readonly=self._file.mode == 'r')
return LindiH5pyAttributes(self._zarr_array.attrs, readonly=self._file.mode == 'r')

@property
def fletcher32(self):
if isinstance(self._dataset_object, h5py.Dataset):
return self._dataset_object.fletcher32
elif isinstance(self._dataset_object, zarr.Array):
for f in self._dataset_object.filters:
if f.__class__.__name__ == 'Fletcher32':
return True
return False
else:
raise Exception(f'Unexpected dataset object type: {type(self._dataset_object)}')
for f in self._zarr_array.filters:
if f.__class__.__name__ == 'Fletcher32':
return True
return False

@property
def chunks(self):
if isinstance(self._dataset_object, h5py.Dataset):
return self._dataset_object.chunks
elif isinstance(self._dataset_object, zarr.Array):
return self._dataset_object.chunks
else:
raise Exception(f'Unexpected dataset object type: {type(self._dataset_object)}')
return self._zarr_array.chunks

def __repr__(self): # type: ignore
return f"<{self.__class__.__name__}: {self.name}>"
Expand All @@ -187,15 +165,9 @@ def __str__(self):
return f"<{self.__class__.__name__}: {self.name}>"

def __getitem__(self, args, new_dtype=None):
if isinstance(self._dataset_object, h5py.Dataset):
ret = self._dataset_object.__getitem__(args, new_dtype)
elif isinstance(self._dataset_object, zarr.Array):
if new_dtype is not None:
raise Exception("new_dtype is not supported for zarr.Array")
ret = self._get_item_for_zarr(self._dataset_object, args)
else:
raise Exception(f"Unexpected type: {type(self._dataset_object)}")
return ret
if new_dtype is not None:
raise Exception("new_dtype is not supported for zarr.Array")
return self._get_item_for_zarr(self._zarr_array, args)

def _get_item_for_zarr(self, zarr_array: zarr.Array, selection: Any):
# First check whether this is an external array link
Expand Down Expand Up @@ -285,11 +257,11 @@ def __init__(self, *, dataset: LindiH5pyDataset, ind: int, dtype: np.dtype):
raise TypeError(
f"Compound field selection only implemented for 1D datasets, not {self._dataset.ndim}D"
)
if not isinstance(self._dataset._dataset_object, zarr.Array):
if not isinstance(self._dataset._zarr_array, zarr.Array):
raise TypeError(
f"Compound field selection only implemented for zarr.Array, not {type(self._dataset._dataset_object)}"
f"Compound field selection only implemented for zarr.Array, not {type(self._dataset._zarr_array)}"
)
za = self._dataset._dataset_object
za = self._dataset._zarr_array
self._zarr_array = za
# Prepare the data in memory
d = [za[i][self._ind] for i in range(len(za))]
Expand Down
Loading