Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
magland committed Mar 20, 2024
1 parent fc3978a commit f1ce861
Show file tree
Hide file tree
Showing 14 changed files with 498 additions and 541 deletions.
81 changes: 67 additions & 14 deletions lindi/LindiH5pyFile/LindiH5pyAttributes.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,81 @@
from typing import Union, Any
from typing import Literal
from .LindiH5pyReference import LindiH5pyReference
from ..LindiZarrWrapper import LindiZarrWrapperAttributes, LindiZarrWrapperReference

_special_attribute_keys = [
"_SCALAR",
"_COMPOUND_DTYPE",
"_REFERENCE",
"_EXTERNAL_ARRAY_LINK",
"_SOFT_LINK",
]


class LindiH5pyAttributes:
def __init__(self, attrs: Union[Any, LindiZarrWrapperAttributes]):
def __init__(self, attrs, attrs_type: Literal["h5py", "zarr"]):
self._attrs = attrs
self._attrs_type = attrs_type

def get(self, key, default=None):
if self._attrs_type == "h5py":
return self._attrs.get(key, default)
elif self._attrs_type == "zarr":
try:
if key in _special_attribute_keys:
raise KeyError
return self[key]
except KeyError:
return default
else:
raise ValueError(f"Unknown attrs_type: {self._attrs_type}")

def __contains__(self, key):
if self._attrs_type == "h5py":
return key in self._attrs
elif self._attrs_type == "zarr":
return key in self._attrs
else:
raise ValueError(f"Unknown attrs_type: {self._attrs_type}")

def __getitem__(self, key):
val = self._attrs[key]
if isinstance(val, LindiZarrWrapperReference):
return LindiH5pyReference(val)
return val
if self._attrs_type == "h5py":
return val
elif self._attrs_type == "zarr":
if isinstance(val, dict) and "_REFERENCE" in val:
return LindiH5pyReference(val["_REFERENCE"])
else:
return val
else:
raise ValueError(f"Unknown attrs_type: {self._attrs_type}")

def get(self, key, default=None):
try:
return self[key]
except KeyError:
return default
def __setitem__(self, key, value):
raise KeyError("Cannot set attributes on read-only object")

def __delitem__(self, key):
raise KeyError("Cannot delete attributes on read-only object")

def __iter__(self):
if self._attrs_type == "h5py":
return self._attrs.__iter__()
elif self._attrs_type == "zarr":
for k in self._attrs:
if k not in _special_attribute_keys:
yield k
else:
raise ValueError(f"Unknown attrs_type: {self._attrs_type}")

def items(self):
for k in self:
yield k, self[k]

def __iter__(self):
for k in self._attrs:
yield k
def __len__(self):
ct = 0
for _ in self:
ct += 1
return ct

def __repr__(self):
return repr(self._attrs)

def __str__(self):
return str(self._attrs)
183 changes: 171 additions & 12 deletions lindi/LindiH5pyFile/LindiH5pyDataset.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import TYPE_CHECKING, Union, Any
import h5py
from typing import TYPE_CHECKING, Union, Any, Dict
import numpy as np
import h5py
import zarr
import remfile

from .LindiH5pyAttributes import LindiH5pyAttributes
from .LindiH5pyReference import LindiH5pyReference
from ..LindiZarrWrapper import LindiZarrWrapperDataset
from ..LindiZarrWrapper import LindiZarrWrapperReference


if TYPE_CHECKING:
Expand All @@ -18,24 +18,55 @@ def __init__(self, _h5py_dataset_id):


class LindiH5pyDataset(h5py.Dataset):
def __init__(self, _dataset_object: Union[h5py.Dataset, LindiZarrWrapperDataset], _file: "LindiH5pyFile"):
def __init__(self, _dataset_object: Union[h5py.Dataset, zarr.Array], _file: "LindiH5pyFile"):
self._dataset_object = _dataset_object
self._file = _file

self._external_hdf5_clients: Dict[str, h5py.File] = {}

# See if we have the _COMPOUND_DTYPE attribute, which signifies that
# this is a compound dtype
if isinstance(_dataset_object, zarr.Array):
compound_dtype_obj = _dataset_object.attrs.get("_COMPOUND_DTYPE", None)
if compound_dtype_obj is not None:
# If we have a compound dtype, then create the numpy dtype
self._compound_dtype = np.dtype(
[(compound_dtype_obj[i][0], compound_dtype_obj[i][1]) for i in range(len(compound_dtype_obj))]
)
else:
self._compound_dtype = None
else:
self._compound_dtype = None

# Check whether this is a scalar dataset
if isinstance(_dataset_object, zarr.Array):
self._is_scalar = self._dataset_object.attrs.get("_SCALAR", False)
else:
self._is_scalar = self._dataset_object.ndim == 0

@property
def id(self):
return LindiH5pyDatasetId(self._dataset_object.id)
if isinstance(self._dataset_object, h5py.Dataset):
return LindiH5pyDatasetId(self._dataset_object.id)
else:
return LindiH5pyDatasetId(None)

@property
def shape(self): # type: ignore
if self._is_scalar:
return ()
return self._dataset_object.shape

@property
def size(self):
if self._is_scalar:
return 1
return self._dataset_object.size

@property
def dtype(self):
if self._compound_dtype is not None:
return self._compound_dtype
return self._dataset_object.dtype

@property
Expand All @@ -58,28 +89,93 @@ def maxshape(self):

@property
def ndim(self):
if self._is_scalar:
return 0
return self._dataset_object.ndim

@property
def attrs(self): # type: ignore
return LindiH5pyAttributes(self._dataset_object.attrs)
if isinstance(self._dataset_object, h5py.Dataset):
attrs_type = 'h5py'
elif isinstance(self._dataset_object, zarr.Array):
attrs_type = 'zarr'
else:
raise Exception(f'Unexpected dataset object type: {type(self._dataset_object)}')
return LindiH5pyAttributes(self._dataset_object.attrs, attrs_type=attrs_type)

def __getitem__(self, args, new_dtype=None):
ret = self._dataset_object.__getitem__(args, new_dtype)
if isinstance(self._dataset_object, LindiZarrWrapperDataset):
if isinstance(self._dataset_object, h5py.Dataset):
ret = self._dataset_object.__getitem__(args, new_dtype)
elif isinstance(self._dataset_object, zarr.Array):
if new_dtype is not None:
raise Exception("new_dtype is not supported for zarr.Array")
ret = self._get_item_for_zarr(self._dataset_object, args)
ret = _resolve_references(ret)
else:
raise Exception(f"Unexpected type: {type(self._dataset_object)}")
return ret

def _get_item_for_zarr(self, zarr_array: zarr.Array, selection: Any):
# First check whether this is an external array link
external_array_link = zarr_array.attrs.get("_EXTERNAL_ARRAY_LINK", None)
if external_array_link and isinstance(external_array_link, dict):
link_type = external_array_link.get("link_type", None)
if link_type == 'hdf5_dataset':
url = external_array_link.get("url", None)
name = external_array_link.get("name", None)
if url is not None and name is not None:
client = self._get_external_hdf5_client(url)
dataset = client[name]
assert isinstance(dataset, h5py.Dataset)
return dataset[selection]
if self._compound_dtype is not None:
# Compound dtype
# In this case we index into the compound dtype using the name of the field
# For example, if the dtype is [('x', 'f4'), ('y', 'f4')], then we can do
# dataset['x'][0] to get the first x value
assert self._compound_dtype.names is not None
if isinstance(selection, str):
# Find the index of this field in the compound dtype
ind = self._compound_dtype.names.index(selection)
# Get the dtype of this field
dt = self._compound_dtype[ind]
if dt == 'object':
dtype = h5py.Reference
else:
dtype = np.dtype(dt)
# Return a new object that can be sliced further
# It's important that the return type is Any here, because otherwise we get linter problems
ret: Any = LindiH5pyDatasetCompoundFieldSelection(
dataset=self, ind=ind, dtype=dtype
)
return ret
else:
raise TypeError(
f"Compound dataset {self.name} does not support selection with {selection}"
)

# We use zarr's slicing, except in the case of a scalar dataset
if self.ndim == 0:
# make sure selection is ()
if selection != ():
raise TypeError(f'Cannot slice a scalar dataset with {selection}')
return zarr_array[0]
return zarr_array[selection]

def _get_external_hdf5_client(self, url: str) -> h5py.File:
if url not in self._external_hdf5_clients:
remf = remfile.File(url)
self._external_hdf5_clients[url] = h5py.File(remf, "r")
return self._external_hdf5_clients[url]


def _resolve_references(x: Any):
if isinstance(x, dict):
if '_REFERENCE' in x:
return LindiH5pyReference(LindiZarrWrapperReference(x['_REFERENCE']))
return LindiH5pyReference(x['_REFERENCE'])
else:
for k, v in x.items():
x[k] = _resolve_references(v)
elif isinstance(x, LindiZarrWrapperReference):
return LindiH5pyReference(x)
elif isinstance(x, list):
for i, v in enumerate(x):
x[i] = _resolve_references(v)
Expand All @@ -89,3 +185,66 @@ def _resolve_references(x: Any):
for i in range(len(view_1d)):
view_1d[i] = _resolve_references(view_1d[i])
return x


class LindiH5pyDatasetCompoundFieldSelection:
"""
This class is returned when a compound dataset is indexed with a field name.
For example, if the dataset has dtype [('x', 'f4'), ('y', 'f4')], then we
can do dataset['x'][0] to get the first x value. The dataset['x'] returns an
object of this class.
"""
def __init__(self, *, dataset: LindiH5pyDataset, ind: int, dtype: np.dtype):
self._dataset = dataset # The parent dataset
self._ind = ind # The index of the field in the compound dtype
self._dtype = dtype # The dtype of the field
if self._dataset.ndim != 1:
# For now we only support 1D datasets
raise TypeError(
f"Compound field selection only implemented for 1D datasets, not {self._dataset.ndim}D"
)
if not isinstance(self._dataset._dataset_object, zarr.Array):
raise TypeError(
f"Compound field selection only implemented for zarr.Array, not {type(self._dataset._dataset_object)}"
)
za = self._dataset._dataset_object
self._zarr_array = za
# Prepare the data in memory
d = [za[i][self._ind] for i in range(len(za))]
if self._dtype == h5py.Reference:
# Convert to LindiH5pyReference
d = [LindiH5pyReference(x['_REFERENCE']) for x in d]
self._data = np.array(d, dtype=self._dtype)

def __len__(self):
"""We conform to h5py, which is the number of elements in the first dimension, TypeError if scalar"""
if self.ndim == 0:
raise TypeError("Scalar dataset")
return self.shape[0] # type: ignore

def __iter__(self):
"""We conform to h5py, which is: Iterate over the first axis. TypeError if scalar."""
shape = self.shape
if len(shape) == 0:
raise TypeError("Can't iterate over a scalar dataset")
for i in range(shape[0]):
yield self[i]

@property
def ndim(self):
return self._zarr_array.ndim

@property
def shape(self):
return self._zarr_array.shape

@property
def dtype(self):
self._dtype

@property
def size(self):
return self._data.size

def __getitem__(self, selection):
return self._data[selection]
Loading

0 comments on commit f1ce861

Please sign in to comment.