diff --git a/lindi/LindiH5pyFile/LindiH5pyAttributes.py b/lindi/LindiH5pyFile/LindiH5pyAttributes.py index 31fa625..59b97b7 100644 --- a/lindi/LindiH5pyFile/LindiH5pyAttributes.py +++ b/lindi/LindiH5pyFile/LindiH5pyAttributes.py @@ -1,28 +1,81 @@ -from typing import Union, Any +from typing import Literal from .LindiH5pyReference import LindiH5pyReference -from ..LindiZarrWrapper import LindiZarrWrapperAttributes, LindiZarrWrapperReference + +_special_attribute_keys = [ + "_SCALAR", + "_COMPOUND_DTYPE", + "_REFERENCE", + "_EXTERNAL_ARRAY_LINK", + "_SOFT_LINK", +] class LindiH5pyAttributes: - def __init__(self, attrs: Union[Any, LindiZarrWrapperAttributes]): + def __init__(self, attrs, attrs_type: Literal["h5py", "zarr"]): self._attrs = attrs + self._attrs_type = attrs_type + + def get(self, key, default=None): + if self._attrs_type == "h5py": + return self._attrs.get(key, default) + elif self._attrs_type == "zarr": + try: + if key in _special_attribute_keys: + raise KeyError + return self[key] + except KeyError: + return default + else: + raise ValueError(f"Unknown attrs_type: {self._attrs_type}") + + def __contains__(self, key): + if self._attrs_type == "h5py": + return key in self._attrs + elif self._attrs_type == "zarr": + return key in self._attrs + else: + raise ValueError(f"Unknown attrs_type: {self._attrs_type}") def __getitem__(self, key): val = self._attrs[key] - if isinstance(val, LindiZarrWrapperReference): - return LindiH5pyReference(val) - return val + if self._attrs_type == "h5py": + return val + elif self._attrs_type == "zarr": + if isinstance(val, dict) and "_REFERENCE" in val: + return LindiH5pyReference(val["_REFERENCE"]) + else: + return val + else: + raise ValueError(f"Unknown attrs_type: {self._attrs_type}") - def get(self, key, default=None): - try: - return self[key] - except KeyError: - return default + def __setitem__(self, key, value): + raise KeyError("Cannot set attributes on read-only object") + + def __delitem__(self, key): + raise KeyError("Cannot delete attributes on read-only object") + + def __iter__(self): + if self._attrs_type == "h5py": + return self._attrs.__iter__() + elif self._attrs_type == "zarr": + for k in self._attrs: + if k not in _special_attribute_keys: + yield k + else: + raise ValueError(f"Unknown attrs_type: {self._attrs_type}") def items(self): for k in self: yield k, self[k] - def __iter__(self): - for k in self._attrs: - yield k + def __len__(self): + ct = 0 + for _ in self: + ct += 1 + return ct + + def __repr__(self): + return repr(self._attrs) + + def __str__(self): + return str(self._attrs) diff --git a/lindi/LindiH5pyFile/LindiH5pyDataset.py b/lindi/LindiH5pyFile/LindiH5pyDataset.py index 27afe93..3962396 100644 --- a/lindi/LindiH5pyFile/LindiH5pyDataset.py +++ b/lindi/LindiH5pyFile/LindiH5pyDataset.py @@ -1,11 +1,11 @@ -from typing import TYPE_CHECKING, Union, Any -import h5py +from typing import TYPE_CHECKING, Union, Any, Dict import numpy as np +import h5py +import zarr +import remfile from .LindiH5pyAttributes import LindiH5pyAttributes from .LindiH5pyReference import LindiH5pyReference -from ..LindiZarrWrapper import LindiZarrWrapperDataset -from ..LindiZarrWrapper import LindiZarrWrapperReference if TYPE_CHECKING: @@ -18,24 +18,55 @@ def __init__(self, _h5py_dataset_id): class LindiH5pyDataset(h5py.Dataset): - def __init__(self, _dataset_object: Union[h5py.Dataset, LindiZarrWrapperDataset], _file: "LindiH5pyFile"): + def __init__(self, _dataset_object: Union[h5py.Dataset, zarr.Array], _file: "LindiH5pyFile"): self._dataset_object = _dataset_object self._file = _file + self._external_hdf5_clients: Dict[str, h5py.File] = {} + + # See if we have the _COMPOUND_DTYPE attribute, which signifies that + # this is a compound dtype + if isinstance(_dataset_object, zarr.Array): + compound_dtype_obj = _dataset_object.attrs.get("_COMPOUND_DTYPE", None) + if compound_dtype_obj is not None: + # If we have a compound dtype, then create the numpy dtype + self._compound_dtype = np.dtype( + [(compound_dtype_obj[i][0], compound_dtype_obj[i][1]) for i in range(len(compound_dtype_obj))] + ) + else: + self._compound_dtype = None + else: + self._compound_dtype = None + + # Check whether this is a scalar dataset + if isinstance(_dataset_object, zarr.Array): + self._is_scalar = self._dataset_object.attrs.get("_SCALAR", False) + else: + self._is_scalar = self._dataset_object.ndim == 0 + @property def id(self): - return LindiH5pyDatasetId(self._dataset_object.id) + if isinstance(self._dataset_object, h5py.Dataset): + return LindiH5pyDatasetId(self._dataset_object.id) + else: + return LindiH5pyDatasetId(None) @property def shape(self): # type: ignore + if self._is_scalar: + return () return self._dataset_object.shape @property def size(self): + if self._is_scalar: + return 1 return self._dataset_object.size @property def dtype(self): + if self._compound_dtype is not None: + return self._compound_dtype return self._dataset_object.dtype @property @@ -58,28 +89,93 @@ def maxshape(self): @property def ndim(self): + if self._is_scalar: + return 0 return self._dataset_object.ndim @property def attrs(self): # type: ignore - return LindiH5pyAttributes(self._dataset_object.attrs) + if isinstance(self._dataset_object, h5py.Dataset): + attrs_type = 'h5py' + elif isinstance(self._dataset_object, zarr.Array): + attrs_type = 'zarr' + else: + raise Exception(f'Unexpected dataset object type: {type(self._dataset_object)}') + return LindiH5pyAttributes(self._dataset_object.attrs, attrs_type=attrs_type) def __getitem__(self, args, new_dtype=None): - ret = self._dataset_object.__getitem__(args, new_dtype) - if isinstance(self._dataset_object, LindiZarrWrapperDataset): + if isinstance(self._dataset_object, h5py.Dataset): + ret = self._dataset_object.__getitem__(args, new_dtype) + elif isinstance(self._dataset_object, zarr.Array): + if new_dtype is not None: + raise Exception("new_dtype is not supported for zarr.Array") + ret = self._get_item_for_zarr(self._dataset_object, args) ret = _resolve_references(ret) + else: + raise Exception(f"Unexpected type: {type(self._dataset_object)}") return ret + def _get_item_for_zarr(self, zarr_array: zarr.Array, selection: Any): + # First check whether this is an external array link + external_array_link = zarr_array.attrs.get("_EXTERNAL_ARRAY_LINK", None) + if external_array_link and isinstance(external_array_link, dict): + link_type = external_array_link.get("link_type", None) + if link_type == 'hdf5_dataset': + url = external_array_link.get("url", None) + name = external_array_link.get("name", None) + if url is not None and name is not None: + client = self._get_external_hdf5_client(url) + dataset = client[name] + assert isinstance(dataset, h5py.Dataset) + return dataset[selection] + if self._compound_dtype is not None: + # Compound dtype + # In this case we index into the compound dtype using the name of the field + # For example, if the dtype is [('x', 'f4'), ('y', 'f4')], then we can do + # dataset['x'][0] to get the first x value + assert self._compound_dtype.names is not None + if isinstance(selection, str): + # Find the index of this field in the compound dtype + ind = self._compound_dtype.names.index(selection) + # Get the dtype of this field + dt = self._compound_dtype[ind] + if dt == 'object': + dtype = h5py.Reference + else: + dtype = np.dtype(dt) + # Return a new object that can be sliced further + # It's important that the return type is Any here, because otherwise we get linter problems + ret: Any = LindiH5pyDatasetCompoundFieldSelection( + dataset=self, ind=ind, dtype=dtype + ) + return ret + else: + raise TypeError( + f"Compound dataset {self.name} does not support selection with {selection}" + ) + + # We use zarr's slicing, except in the case of a scalar dataset + if self.ndim == 0: + # make sure selection is () + if selection != (): + raise TypeError(f'Cannot slice a scalar dataset with {selection}') + return zarr_array[0] + return zarr_array[selection] + + def _get_external_hdf5_client(self, url: str) -> h5py.File: + if url not in self._external_hdf5_clients: + remf = remfile.File(url) + self._external_hdf5_clients[url] = h5py.File(remf, "r") + return self._external_hdf5_clients[url] + def _resolve_references(x: Any): if isinstance(x, dict): if '_REFERENCE' in x: - return LindiH5pyReference(LindiZarrWrapperReference(x['_REFERENCE'])) + return LindiH5pyReference(x['_REFERENCE']) else: for k, v in x.items(): x[k] = _resolve_references(v) - elif isinstance(x, LindiZarrWrapperReference): - return LindiH5pyReference(x) elif isinstance(x, list): for i, v in enumerate(x): x[i] = _resolve_references(v) @@ -89,3 +185,66 @@ def _resolve_references(x: Any): for i in range(len(view_1d)): view_1d[i] = _resolve_references(view_1d[i]) return x + + +class LindiH5pyDatasetCompoundFieldSelection: + """ + This class is returned when a compound dataset is indexed with a field name. + For example, if the dataset has dtype [('x', 'f4'), ('y', 'f4')], then we + can do dataset['x'][0] to get the first x value. The dataset['x'] returns an + object of this class. + """ + def __init__(self, *, dataset: LindiH5pyDataset, ind: int, dtype: np.dtype): + self._dataset = dataset # The parent dataset + self._ind = ind # The index of the field in the compound dtype + self._dtype = dtype # The dtype of the field + if self._dataset.ndim != 1: + # For now we only support 1D datasets + raise TypeError( + f"Compound field selection only implemented for 1D datasets, not {self._dataset.ndim}D" + ) + if not isinstance(self._dataset._dataset_object, zarr.Array): + raise TypeError( + f"Compound field selection only implemented for zarr.Array, not {type(self._dataset._dataset_object)}" + ) + za = self._dataset._dataset_object + self._zarr_array = za + # Prepare the data in memory + d = [za[i][self._ind] for i in range(len(za))] + if self._dtype == h5py.Reference: + # Convert to LindiH5pyReference + d = [LindiH5pyReference(x['_REFERENCE']) for x in d] + self._data = np.array(d, dtype=self._dtype) + + def __len__(self): + """We conform to h5py, which is the number of elements in the first dimension, TypeError if scalar""" + if self.ndim == 0: + raise TypeError("Scalar dataset") + return self.shape[0] # type: ignore + + def __iter__(self): + """We conform to h5py, which is: Iterate over the first axis. TypeError if scalar.""" + shape = self.shape + if len(shape) == 0: + raise TypeError("Can't iterate over a scalar dataset") + for i in range(shape[0]): + yield self[i] + + @property + def ndim(self): + return self._zarr_array.ndim + + @property + def shape(self): + return self._zarr_array.shape + + @property + def dtype(self): + self._dtype + + @property + def size(self): + return self._data.size + + def __getitem__(self, selection): + return self._data[selection] diff --git a/lindi/LindiH5pyFile/LindiH5pyFile.py b/lindi/LindiH5pyFile/LindiH5pyFile.py index 2ddf0d0..8c1f4f9 100644 --- a/lindi/LindiH5pyFile/LindiH5pyFile.py +++ b/lindi/LindiH5pyFile/LindiH5pyFile.py @@ -1,38 +1,67 @@ from typing import Union +import json +import tempfile +import urllib.request import h5py import zarr +from zarr.storage import Store as ZarrStore +from fsspec.implementations.reference import ReferenceFileSystem +from fsspec import FSMap from .LindiH5pyGroup import LindiH5pyGroup from .LindiH5pyDataset import LindiH5pyDataset -from ..LindiZarrWrapper import LindiZarrWrapper, LindiZarrWrapperGroup, LindiZarrWrapperDataset, LindiZarrWrapperReference from .LindiH5pyAttributes import LindiH5pyAttributes from .LindiH5pyReference import LindiH5pyReference class LindiH5pyFile(h5py.File): - def __init__(self, _file_object: Union[h5py.File, LindiZarrWrapper]): + def __init__(self, _file_object: Union[h5py.File, zarr.Group]): """ - Do not use this constructor directly. Instead, use - from_reference_file_system, from_zarr_group, or from_h5py_file. + Do not use this constructor directly. Instead, use: + from_reference_file_system, from_zarr_store, from_zarr_group, + or from_h5py_file """ self._file_object = _file_object self._the_group = LindiH5pyGroup(_file_object, self) @staticmethod - def from_reference_file_system(rfs: dict): + def from_reference_file_system(rfs: Union[dict, str]): """ Create a LindiH5pyFile from a reference file system. """ - x = LindiZarrWrapper.from_reference_file_system(rfs) - return LindiH5pyFile(x) + if isinstance(rfs, str): + if rfs.startswith("http") or rfs.startswith("https"): + with tempfile.TemporaryDirectory() as tmpdir: + filename = f"{tmpdir}/temp.zarr.json" + _download_file(rfs, filename) + with open(filename, "r") as f: + data = json.load(f) + assert isinstance(data, dict) # prevent infinite recursion + return LindiH5pyFile.from_reference_file_system(data) + else: + with open(rfs, "r") as f: + data = json.load(f) + assert isinstance(data, dict) # prevent infinite recursion + return LindiH5pyFile.from_reference_file_system(data) + else: + fs = ReferenceFileSystem(rfs).get_mapper(root="") + return LindiH5pyFile.from_zarr_store(fs) + + @staticmethod + def from_zarr_store(zarr_store: Union[ZarrStore, FSMap]): + """ + Create a LindiH5pyFile from a zarr store. + """ + zarr_group = zarr.open(store=zarr_store, mode="r") + assert isinstance(zarr_group, zarr.Group) + return LindiH5pyFile.from_zarr_group(zarr_group) @staticmethod def from_zarr_group(zarr_group: zarr.Group): """ Create a LindiH5pyFile from a zarr group. """ - x = LindiZarrWrapper.from_zarr_group(zarr_group) - return LindiH5pyFile(x) + return LindiH5pyFile(zarr_group) @staticmethod def from_h5py_file(h5py_file: h5py.File): @@ -43,12 +72,23 @@ def from_h5py_file(h5py_file: h5py.File): @property def attrs(self): # type: ignore - return LindiH5pyAttributes(self._file_object.attrs) + if isinstance(self._file_object, h5py.File): + attrs_type = 'h5py' + elif isinstance(self._file_object, zarr.Group): + attrs_type = 'zarr' + else: + raise Exception(f'Unexpected file object type: {type(self._file_object)}') + return LindiH5pyAttributes(self._file_object.attrs, attrs_type=attrs_type) @property def filename(self): # This is not a string, but this is what h5py seems to do - return self._file_object.filename + if isinstance(self._file_object, h5py.File): + return self._file_object.filename + elif isinstance(self._file_object, zarr.Group): + return '' + else: + raise Exception(f"Unhandled type for file object: {type(self._file_object)}") @property def driver(self): @@ -58,7 +98,7 @@ def driver(self): def mode(self): if isinstance(self._file_object, h5py.File): return self._file_object.mode - elif isinstance(self._file_object, LindiZarrWrapper): + elif isinstance(self._file_object, zarr.Group): # hard-coded to read-only return "r" else: @@ -94,26 +134,30 @@ def __exit__(self, *args): self.close() def __repr__(self): - return f'' + return f'' # Group methods - - def __getitem__(self, name): - if isinstance(name, LindiZarrWrapperReference): - # annoyingly we have to do this because references - # in arrays of compound types will come in as LindiZarrWrapperReference - name = LindiH5pyReference(name) - if isinstance(name, LindiH5pyReference): - assert isinstance(self._file_object, LindiZarrWrapper) - x = self._file_object[name._reference] - if isinstance(x, LindiZarrWrapperGroup): - return LindiH5pyGroup(x, self) - elif isinstance(x, LindiZarrWrapperDataset): - return LindiH5pyDataset(x, self) - else: - raise Exception(f"Unexpected type for resolved reference at path {name}: {type(x)}") - elif isinstance(name, h5py.Reference): - assert isinstance(self._file_object, h5py.File) + def __getitem__(self, name): # type: ignore + return self._get_item(name) + + def _get_item(self, name, getlink=False, default=None): + if isinstance(name, LindiH5pyReference) and isinstance(self._file_object, zarr.Group): + if getlink: + raise Exception("Getting link is not allowed for references") + zarr_group = self._file_object + if name._source != '.': + raise Exception(f'For now, source of reference must be ".", got "{name._source}"') + if name._source_object_id is not None: + if name._source_object_id != zarr_group.attrs.get("object_id"): + raise Exception(f'Mismatch in source object_id: "{name._source_object_id}" and "{zarr_group.attrs.get("object_id")}"') + target = self[name._path] + if name._object_id is not None: + if name._object_id != target.attrs.get("object_id"): + raise Exception(f'Mismatch in object_id: "{name._object_id}" and "{target.attrs.get("object_id")}"') + return target + elif isinstance(name, h5py.Reference) and isinstance(self._file_object, h5py.File): + if getlink: + raise Exception("Getting link is not allowed for references") x = self._file_object[name] if isinstance(x, h5py.Group): return LindiH5pyGroup(x, self) @@ -125,13 +169,20 @@ def __getitem__(self, name): if isinstance(name, str) and "/" in name: parts = name.split("/") x = self._the_group - for part in parts: - x = x[part] + for i, part in enumerate(parts): + if i == len(parts) - 1: + assert isinstance(x, LindiH5pyGroup) + x = x.get(part, default=default, getlink=getlink) + else: + assert isinstance(x, LindiH5pyGroup) + x = x.get(part) return x - return self._the_group[name] + return self._the_group.get(name, default=default, getlink=getlink) def get(self, name, default=None, getclass=False, getlink=False): - return self._the_group.get(name, default, getclass, getlink) + if getclass: + raise Exception("Getting class is not allowed") + return self._get_item(name, getlink=getlink, default=default) def __iter__(self): return self._the_group.__iter__() @@ -153,3 +204,13 @@ def file(self): @property def name(self): return self._the_group.name + + +def _download_file(url: str, filename: str) -> None: + headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3" + } + req = urllib.request.Request(url, headers=headers) + with urllib.request.urlopen(req) as response: + with open(filename, "wb") as f: + f.write(response.read()) diff --git a/lindi/LindiH5pyFile/LindiH5pyGroup.py b/lindi/LindiH5pyFile/LindiH5pyGroup.py index 021f9d8..8678f71 100644 --- a/lindi/LindiH5pyFile/LindiH5pyGroup.py +++ b/lindi/LindiH5pyFile/LindiH5pyGroup.py @@ -1,10 +1,9 @@ from typing import TYPE_CHECKING, Union import h5py +import zarr -from lindi.LindiZarrWrapper import LindiZarrWrapperDataset from .LindiH5pyDataset import LindiH5pyDataset from .LindiH5pyLink import LindiH5pyHardLink, LindiH5pySoftLink -from ..LindiZarrWrapper import LindiZarrWrapperGroup from .LindiH5pyAttributes import LindiH5pyAttributes @@ -18,7 +17,7 @@ def __init__(self, _h5py_group_id): class LindiH5pyGroup(h5py.Group): - def __init__(self, _group_object: Union[h5py.Group, LindiZarrWrapperGroup], _file: "LindiH5pyFile"): + def __init__(self, _group_object: Union[h5py.Group, zarr.Group], _file: "LindiH5pyFile"): self._group_object = _group_object self._file = _file @@ -37,7 +36,7 @@ def __getitem__(self, name): return LindiH5pyDataset(x, self._file) else: raise Exception(f"Unknown type: {type(x)}") - elif isinstance(self._group_object, LindiZarrWrapperGroup): + elif isinstance(self._group_object, zarr.Group): if isinstance(name, (bytes, str)): x = self._group_object[name] else: @@ -45,10 +44,11 @@ def __getitem__(self, name): "Accessing a group is done with bytes or str, " "not {}".format(type(name)) ) - if isinstance(x, LindiZarrWrapperGroup): + if isinstance(x, zarr.Group): # follow the link if this is a soft link - if x.soft_link is not None: - link_path = x.soft_link['path'] + soft_link = x.attrs.get('_SOFT_LINK', None) + if soft_link is not None: + link_path = soft_link['path'] target_grp = self._file.get(link_path) if not isinstance(target_grp, LindiH5pyGroup): raise Exception( @@ -56,7 +56,7 @@ def __getitem__(self, name): ) return target_grp return LindiH5pyGroup(x, self._file) - elif isinstance(x, LindiZarrWrapperDataset): + elif isinstance(x, zarr.Array): return LindiH5pyDataset(x, self._file) else: raise Exception(f"Unknown type: {type(x)}") @@ -85,10 +85,13 @@ def get(self, name, default=None, getclass=False, getlink=False): raise Exception( f"Unhandled type for get with getlink at {self.name} {name}: {type(x)}" ) - elif isinstance(self._group_object, LindiZarrWrapperGroup): + elif isinstance(self._group_object, zarr.Group): x = self._group_object.get(name, default=default) - if isinstance(x, LindiZarrWrapperGroup) and x.soft_link is not None: - return LindiH5pySoftLink(x.soft_link['path']) + if x is None: + return default + soft_link = x.attrs.get('_SOFT_LINK', None) + if isinstance(x, zarr.Group) and soft_link is not None: + return LindiH5pySoftLink(soft_link['path']) else: return LindiH5pyHardLink() else: @@ -100,18 +103,26 @@ def get(self, name, default=None, getclass=False, getlink=False): def name(self): return self._group_object.name + def keys(self): # type: ignore + return self._group_object.keys() + def __iter__(self): return self._group_object.__iter__() def __reversed__(self): - return self._group_object.__reversed__() + raise Exception("Not implemented: __reversed__") def __contains__(self, name): return self._group_object.__contains__(name) @property def id(self): - return LindiH5pyGroupId(self._group_object.id) + if isinstance(self._group_object, h5py.Group): + return LindiH5pyGroupId(self._group_object.id) + elif isinstance(self._group_object, zarr.Group): + return LindiH5pyGroupId(None) + else: + raise Exception(f'Unexpected group object type: {type(self._group_object)}') @property def file(self): @@ -119,4 +130,10 @@ def file(self): @property def attrs(self): # type: ignore - return LindiH5pyAttributes(self._group_object.attrs) + if isinstance(self._group_object, h5py.Group): + attrs_type = 'h5py' + elif isinstance(self._group_object, zarr.Group): + attrs_type = 'zarr' + else: + raise Exception(f'Unexpected group object type: {type(self._group_object)}') + return LindiH5pyAttributes(self._group_object.attrs, attrs_type=attrs_type) diff --git a/lindi/LindiH5pyFile/LindiH5pyReference.py b/lindi/LindiH5pyFile/LindiH5pyReference.py index 491d79f..e1bda22 100644 --- a/lindi/LindiH5pyFile/LindiH5pyReference.py +++ b/lindi/LindiH5pyFile/LindiH5pyReference.py @@ -1,13 +1,15 @@ import h5py -from ..LindiZarrWrapper import LindiZarrWrapperReference class LindiH5pyReference(h5py.h5r.Reference): - def __init__(self, reference: LindiZarrWrapperReference): - self._reference = reference + def __init__(self, obj: dict): + self._object_id = obj["object_id"] + self._path = obj["path"] + self._source = obj["source"] + self._source_object_id = obj["source_object_id"] def __repr__(self): - return f"LindiH5pyReference({self._reference})" + return f"LindiH5pyReference({self._object_id}, {self._path})" def __str__(self): - return f"LindiH5pyReference({self._reference})" + return f"LindiH5pyReference({self._object_id}, {self._path})" diff --git a/lindi/LindiZarrWrapper/LindiZarrWrapper.py b/lindi/LindiZarrWrapper/LindiZarrWrapper.py deleted file mode 100644 index 29fabd5..0000000 --- a/lindi/LindiZarrWrapper/LindiZarrWrapper.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import Union -import json -import tempfile -from typing import Literal -from fsspec import FSMap -import zarr -import urllib.request -from fsspec.implementations.reference import ReferenceFileSystem -from zarr.storage import Store -from .LindiZarrWrapperGroup import LindiZarrWrapperGroup -from .LindiZarrWrapperReference import LindiZarrWrapperReference - - -class LindiZarrWrapper(LindiZarrWrapperGroup): - def __init__( - self, - *, - _zarr_group: zarr.Group, - ) -> None: - self._zarr_group = _zarr_group - super().__init__(_zarr_group=self._zarr_group, _client=self) - - @property - def filename(self): - return '' - - @staticmethod - def from_zarr_store(zarr_store: Union[Store, FSMap]) -> "LindiZarrWrapper": - zarr_group = zarr.open(store=zarr_store, mode="r") - assert isinstance(zarr_group, zarr.Group) - return LindiZarrWrapper.from_zarr_group(zarr_group) - - @staticmethod - def from_file( - json_file: str, file_type: Literal["zarr.json"] = "zarr.json" - ) -> "LindiZarrWrapper": - if file_type == "zarr.json": - if json_file.startswith("http") or json_file.startswith("https"): - with tempfile.TemporaryDirectory() as tmpdir: - filename = f"{tmpdir}/temp.zarr.json" - _download_file(json_file, filename) - with open(filename, "r") as f: - data = json.load(f) - return LindiZarrWrapper.from_reference_file_system(data) - else: - with open(json_file, "r") as f: - data = json.load(f) - return LindiZarrWrapper.from_reference_file_system(data) - else: - raise ValueError(f"Unknown file_type: {file_type}") - - @staticmethod - def from_zarr_group(zarr_group: zarr.Group) -> "LindiZarrWrapper": - return LindiZarrWrapper(_zarr_group=zarr_group) - - @staticmethod - def from_reference_file_system(data: dict) -> "LindiZarrWrapper": - fs = ReferenceFileSystem(data).get_mapper(root="") - return LindiZarrWrapper.from_zarr_store(fs) - - def get(self, key, default=None): - try: - ret = self[key] - except KeyError: - ret = default - return ret - - def __getitem__(self, key): # type: ignore - if isinstance(key, str): - if key.startswith('/'): - key = key[1:] - parts = key.split("/") - if len(parts) == 1: - return super().__getitem__(key) - else: - g = self - for part in parts: - g = g[part] - return g - elif isinstance(key, LindiZarrWrapperReference): - if key._source != '.': - raise Exception(f'For now, source of reference must be ".", got "{key._source}"') - if key._source_object_id is not None: - if key._source_object_id != self._zarr_group.attrs.get("object_id"): - raise Exception(f'Mismatch in source object_id: "{key._source_object_id}" and "{self._zarr_group.attrs.get("object_id")}"') - target = self[key._path] - if key._object_id is not None: - if key._object_id != target.attrs.get("object_id"): - raise Exception(f'Mismatch in object_id: "{key._object_id}" and "{target.attrs.get("object_id")}"') - return target - else: - raise Exception(f'Cannot use key "{key}" of type "{type(key)}" to index into a LindiZarrWrapper') - - -def _download_file(url: str, filename: str) -> None: - headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3" - } - req = urllib.request.Request(url, headers=headers) - with urllib.request.urlopen(req) as response: - with open(filename, "wb") as f: - f.write(response.read()) diff --git a/lindi/LindiZarrWrapper/LindiZarrWrapperAttributes.py b/lindi/LindiZarrWrapper/LindiZarrWrapperAttributes.py deleted file mode 100644 index c80a41a..0000000 --- a/lindi/LindiZarrWrapper/LindiZarrWrapperAttributes.py +++ /dev/null @@ -1,52 +0,0 @@ -from typing import Union -import zarr -from .LindiZarrWrapperReference import LindiZarrWrapperReference - - -_special_attribute_keys = ["_SCALAR", "_COMPOUND_DTYPE", "_REFERENCE", "_EXTERNAL_ARRAY_LINK", "_SOFT_LINK"] - - -class LindiZarrWrapperAttributes: - def __init__(self, *, _object: Union[zarr.Group, zarr.Array]): - self._object = _object - - def get(self, key, default=None): - try: - if key in _special_attribute_keys: - raise KeyError - return self[key] - except KeyError: - return default - - def __getitem__(self, key): - val = self._object.attrs[key] - if isinstance(val, dict) and "_REFERENCE" in val: - return LindiZarrWrapperReference(val["_REFERENCE"]) - return self._object.attrs[key] - - def __setitem__(self, key, value): - raise KeyError("Cannot set attributes on read-only object") - - def __delitem__(self, key): - raise KeyError("Cannot delete attributes on read-only object") - - def __iter__(self): - for k in self._object.attrs: - if k not in _special_attribute_keys: - yield k - - def items(self): - for k in self: - yield k, self[k] - - def __len__(self): - ct = 0 - for _ in self: - ct += 1 - return ct - - def __repr__(self): - return repr(self._object.attrs) - - def __str__(self): - return str(self._object.attrs) diff --git a/lindi/LindiZarrWrapper/LindiZarrWrapperDataset.py b/lindi/LindiZarrWrapper/LindiZarrWrapperDataset.py deleted file mode 100644 index 698b16d..0000000 --- a/lindi/LindiZarrWrapper/LindiZarrWrapperDataset.py +++ /dev/null @@ -1,192 +0,0 @@ -from typing import Dict, Any -import numpy as np -import zarr -import h5py -import remfile -from .LindiZarrWrapperAttributes import LindiZarrWrapperAttributes -from .LindiZarrWrapperReference import LindiZarrWrapperReference - - -class LindiZarrWrapperDataset: - def __init__(self, *, _zarr_array: zarr.Array, _client): - self._zarr_array = _zarr_array - self._is_scalar = self._zarr_array.attrs.get("_SCALAR", False) - self._client = _client - - # See if we have the _COMPOUND_DTYPE attribute, which signifies that - # this is a compound dtype - compound_dtype_obj = self._zarr_array.attrs.get("_COMPOUND_DTYPE", None) - if compound_dtype_obj is not None: - # If we have a compound dtype, then create the numpy dtype - self._compound_dtype = np.dtype( - [(compound_dtype_obj[i][0], compound_dtype_obj[i][1]) for i in range(len(compound_dtype_obj))] - ) - else: - self._compound_dtype = None - - self._external_hdf5_clients: Dict[str, h5py.File] = {} - - @property - def file(self): - return self._client - - @property - def id(self): - return None - - @property - def name(self): - return self._zarr_array.name - - @property - def attrs(self): - """Attributes attached to this object""" - return LindiZarrWrapperAttributes(_object=self._zarr_array) - - @property - def ndim(self): - if self._is_scalar: - return 0 - return self._zarr_array.ndim - - @property - def shape(self): - if self._is_scalar: - return () - return self._zarr_array.shape - - @property - def dtype(self): - if self._compound_dtype is not None: - return self._compound_dtype - return self._zarr_array.dtype - - @property - def size(self): - if self._is_scalar: - return 1 - return self._zarr_array.size - - @property - def nbytes(self): - return self._zarr_array.nbytes - - def __len__(self): - """We conform to h5py, which is the number of elements in the first dimension, TypeError if scalar""" - if self.ndim == 0: - raise TypeError("Scalar dataset") - return self.shape[0] # type: ignore - - def __iter__(self): - """We conform to h5py, which is: Iterate over the first axis. TypeError if scalar.""" - shape = self.shape - if len(shape) == 0: - raise TypeError("Can't iterate over a scalar dataset") - for i in range(shape[0]): - yield self[i] - - def __getitem__(self, selection, new_dtype=None): - if new_dtype is not None: - raise TypeError("new_dtype not supported in LindiZarrWrapperDataset.__getitem__") - # First check whether this is an external array link - external_array_link = self._zarr_array.attrs.get("_EXTERNAL_ARRAY_LINK", None) - if external_array_link and isinstance(external_array_link, dict): - link_type = external_array_link.get("link_type", None) - if link_type == 'hdf5_dataset': - url = external_array_link.get("url", None) - name = external_array_link.get("name", None) - if url is not None and name is not None: - client = self._get_external_hdf5_client(url) - dataset = client[name] - assert isinstance(dataset, h5py.Dataset) - return dataset[selection] - if self._compound_dtype is not None: - # Compound dtype - # In this case we index into the compound dtype using the name of the field - # For example, if the dtype is [('x', 'f4'), ('y', 'f4')], then we can do - # dataset['x'][0] to get the first x value - assert self._compound_dtype.names is not None - if isinstance(selection, str): - # Find the index of this field in the compound dtype - ind = self._compound_dtype.names.index(selection) - # Get the dtype of this field - dt = self._compound_dtype[ind] - if dt == 'object': - dtype = h5py.Reference - else: - dtype = np.dtype(dt) - # Return a new object that can be sliced further - # It's important that the return type is Any here, because otherwise we get linter problems - ret: Any = LindiZarrWrapperDatasetCompoundFieldSelection( - dataset=self, ind=ind, dtype=dtype - ) - return ret - else: - raise TypeError( - f"Compound dataset {self.name} does not support selection with {selection}" - ) - - # We use zarr's slicing, except in the case of a scalar dataset - if self.ndim == 0: - # make sure selection is () - if selection != (): - raise TypeError(f'Cannot slice a scalar dataset with {selection}') - return self._zarr_array[0] - return self._zarr_array[selection] - - def _get_external_hdf5_client(self, url: str) -> h5py.File: - if url not in self._external_hdf5_clients: - remf = remfile.File(url) - self._external_hdf5_clients[url] = h5py.File(remf, "r") - return self._external_hdf5_clients[url] - - -class LindiZarrWrapperDatasetCompoundFieldSelection: - """ - This class is returned when a compound dataset is indexed with a field name. - For example, if the dataset has dtype [('x', 'f4'), ('y', 'f4')], then we - can do dataset['x'][0] to get the first x value. The dataset['x'] returns an - object of this class. - """ - def __init__(self, *, dataset: LindiZarrWrapperDataset, ind: int, dtype: np.dtype): - self._dataset = dataset # The parent dataset - self._ind = ind # The index of the field in the compound dtype - self._dtype = dtype # The dtype of the field - if self._dataset.ndim != 1: - # For now we only support 1D datasets - raise TypeError( - f"Compound field selection only implemented for 1D datasets, not {self._dataset.ndim}D" - ) - # Prepare the data in memory - za = self._dataset._zarr_array - d = [za[i][self._ind] for i in range(len(za))] - if self._dtype == h5py.Reference: - # Convert to LindiZarrWrapperReference - d = [LindiZarrWrapperReference(x['_REFERENCE']) for x in d] - self._data = np.array(d, dtype=self._dtype) - - def __len__(self): - return self._dataset._zarr_array.shape[0] - - def __iter__(self): - for i in range(len(self)): - yield self[i] - - @property - def ndim(self): - return self._dataset._zarr_array.ndim - - @property - def shape(self): - return self._dataset._zarr_array.shape - - @property - def dtype(self): - self._dtype - - @property - def size(self): - return self._data.size - - def __getitem__(self, selection): - return self._data[selection] diff --git a/lindi/LindiZarrWrapper/LindiZarrWrapperGroup.py b/lindi/LindiZarrWrapper/LindiZarrWrapperGroup.py deleted file mode 100644 index 696cd1f..0000000 --- a/lindi/LindiZarrWrapper/LindiZarrWrapperGroup.py +++ /dev/null @@ -1,72 +0,0 @@ -from typing import TYPE_CHECKING -import zarr -from .LindiZarrWrapperAttributes import LindiZarrWrapperAttributes -from .LindiZarrWrapperDataset import LindiZarrWrapperDataset - - -if TYPE_CHECKING: - from .LindiZarrWrapper import LindiZarrWrapper - - -class LindiZarrWrapperGroup: - def __init__(self, *, _zarr_group: zarr.Group, _client: "LindiZarrWrapper"): - self._zarr_group = _zarr_group - self._client = _client - - @property - def file(self): - return self._client - - @property - def id(self): - return None - - @property - def attrs(self): - """Attributes attached to this object""" - return LindiZarrWrapperAttributes(_object=self._zarr_group) - - def keys(self): - return self._zarr_group.keys() - - @property - def name(self): - return self._zarr_group.name - - @property - def soft_link(self): - x = self._zarr_group.attrs.get("_SOFT_LINK", None) - return x - - def get(self, key, default=None): - try: - return self[key] - except KeyError: - return default - - def __getitem__(self, key): - if not isinstance(key, str): - raise Exception( - f'Cannot use key "{key}" of type "{type(key)}" to index into a LindiZarrWrapperGroup, at path "{self._zarr_group.name}"' - ) - if key in self._zarr_group.keys(): - x = self._zarr_group[key] - if isinstance(x, zarr.Group): - return LindiZarrWrapperGroup(_zarr_group=x, _client=self._client) - elif isinstance(x, zarr.Array): - return LindiZarrWrapperDataset(_zarr_array=x, _client=self._client) - else: - raise Exception(f"Unknown type: {type(x)}") - else: - raise KeyError(f'Key "{key}" not found in group "{self._zarr_group.name}"') - - def __iter__(self): - for k in self.keys(): - yield k - - def __reversed__(self): - for k in reversed(self.keys()): - yield k - - def __contains__(self, key): - return key in self._zarr_group diff --git a/lindi/LindiZarrWrapper/LindiZarrWrapperReference.py b/lindi/LindiZarrWrapper/LindiZarrWrapperReference.py deleted file mode 100644 index 3d56069..0000000 --- a/lindi/LindiZarrWrapper/LindiZarrWrapperReference.py +++ /dev/null @@ -1,21 +0,0 @@ -import h5py - - -# We need h5py.Reference as a base class so that type checking will be okay for -# arrays of compound types that contain references -class LindiZarrWrapperReference(h5py.Reference): - def __init__(self, obj: dict): - self._object_id = obj["object_id"] - self._path = obj["path"] - self._source = obj["source"] - self._source_object_id = obj["source_object_id"] - - @property - def name(self): - return self._path - - def __repr__(self): - return f"LindiZarrWrapperReference({self._source}, {self._path})" - - def __str__(self): - return f"LindiZarrWrapperReference({self._source}, {self._path})" diff --git a/lindi/LindiZarrWrapper/__init__.py b/lindi/LindiZarrWrapper/__init__.py deleted file mode 100644 index 063f684..0000000 --- a/lindi/LindiZarrWrapper/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .LindiZarrWrapper import LindiZarrWrapper # noqa: F401 -from .LindiZarrWrapperGroup import LindiZarrWrapperGroup # noqa: F401 -from .LindiZarrWrapperDataset import LindiZarrWrapperDataset # noqa: F401 -from .LindiZarrWrapperAttributes import LindiZarrWrapperAttributes # noqa: F401 -from .LindiZarrWrapperReference import LindiZarrWrapperReference # noqa: F401 diff --git a/lindi/__init__.py b/lindi/__init__.py index e4139a4..3c9754a 100644 --- a/lindi/__init__.py +++ b/lindi/__init__.py @@ -1,3 +1,2 @@ -from .LindiZarrWrapper import LindiZarrWrapper, LindiZarrWrapperGroup, LindiZarrWrapperDataset, LindiZarrWrapperAttributes, LindiZarrWrapperReference # noqa: F401 from .LindiH5ZarrStore import LindiH5ZarrStore, LindiH5ZarrStoreOpts # noqa: F401 from .LindiH5pyFile import LindiH5pyFile, LindiH5pyGroup, LindiH5pyDataset, LindiH5pyHardLink, LindiH5pySoftLink # noqa: F401 diff --git a/tests/test_core.py b/tests/test_core.py index ac2073e..1743b17 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -2,11 +2,7 @@ import h5py import tempfile import lindi -from lindi import ( - LindiH5ZarrStore, - LindiZarrWrapper, - LindiZarrWrapperDataset -) +from lindi import LindiH5ZarrStore def test_variety(): @@ -37,15 +33,15 @@ def test_variety(): assert _lists_are_equal(h5f_2.attrs["list1"], h5f.attrs["list1"]) assert _lists_are_equal(h5f_2.attrs["tuple1"], h5f.attrs["tuple1"]) assert _arrays_are_equal(np.array(h5f_2.attrs["array1"]), h5f.attrs["array1"]) - assert h5f_2["dataset1"].attrs["test_attr1"] == h5f["dataset1"].attrs["test_attr1"] + assert h5f_2["dataset1"].attrs["test_attr1"] == h5f["dataset1"].attrs["test_attr1"] # type: ignore assert _arrays_are_equal(h5f_2["dataset1"][()], h5f["dataset1"][()]) # type: ignore - assert h5f_2["group1"].attrs["test_attr2"] == h5f["group1"].attrs["test_attr2"] + assert h5f_2["group1"].attrs["test_attr2"] == h5f["group1"].attrs["test_attr2"] # type: ignore target_1 = h5f[h5f.attrs["dataset1_ref"]] target_2 = h5f_2[h5f_2.attrs["dataset1_ref"]] - assert target_1.attrs["test_attr1"] == target_2.attrs["test_attr1"] + assert target_1.attrs["test_attr1"] == target_2.attrs["test_attr1"] # type: ignore target_1 = h5f[h5f.attrs["group1_ref"]] target_2 = h5f_2[h5f_2.attrs["group1_ref"]] - assert target_1.attrs["test_attr2"] == target_2.attrs["test_attr2"] + assert target_1.attrs["test_attr2"] == target_2.attrs["test_attr2"] # type: ignore def test_soft_links(): @@ -62,10 +58,10 @@ def test_soft_links(): h5f_2 = lindi.LindiH5pyFile.from_reference_file_system(rfs) g1 = h5f['group_target'] g2 = h5f_2['group_target'] - assert g1.attrs['foo'] == g2.attrs['foo'] + assert g1.attrs['foo'] == g2.attrs['foo'] # type: ignore h1 = h5f['soft_link'] h2 = h5f_2['soft_link'] - assert h1.attrs['foo'] == h2.attrs['foo'] + assert h1.attrs['foo'] == h2.attrs['foo'] # type: ignore # this is tricky: it seems that with h5py, the name of the soft link # is the source name. So the following assertion will fail. # assert h1.name == h2.name @@ -75,7 +71,7 @@ def test_soft_links(): assert isinstance(k2, h5py.SoftLink) ds1 = h5f['soft_link']['dataset1'] # type: ignore assert isinstance(ds1, h5py.Dataset) - ds2 = h5f_2['soft_link']['dataset1'] + ds2 = h5f_2['soft_link']['dataset1'] # type: ignore assert isinstance(ds2, h5py.Dataset) assert _arrays_are_equal(ds1[()], ds2[()]) ds1 = h5f['soft_link/dataset1'] @@ -213,12 +209,12 @@ def test_numpy_arrays(): filename, url=filename ) as store: # set url so that a reference file system can be created rfs = store.to_reference_file_system() - client = LindiZarrWrapper.from_reference_file_system(rfs) + client = lindi.LindiH5pyFile.from_reference_file_system(rfs) h5f = h5py.File(filename, "r") X1 = h5f["X"] assert isinstance(X1, h5py.Dataset) X2 = client["X"] - assert isinstance(X2, LindiZarrWrapperDataset) + assert isinstance(X2, lindi.LindiH5pyDataset) assert X1.shape == X2.shape assert X1.dtype == X2.dtype @@ -238,12 +234,12 @@ def test_nan_inf_attributes(): h5f = h5py.File(filename, "r") with LindiH5ZarrStore.from_file(filename, url=filename) as store: rfs = store.to_reference_file_system() - client = LindiZarrWrapper.from_reference_file_system(rfs) + client = lindi.LindiH5pyFile.from_reference_file_system(rfs) X1 = h5f["X"] assert isinstance(X1, h5py.Dataset) X2 = client["X"] - assert isinstance(X2, LindiZarrWrapperDataset) + assert isinstance(X2, lindi.LindiH5pyDataset) assert X2.attrs["nan"] == "NaN" assert X2.attrs["inf"] == "Infinity" @@ -272,4 +268,4 @@ def _arrays_are_equal(a, b): if __name__ == '__main__': - test_arrays_of_compound_dtype_with_references() + test_scalar_arrays() diff --git a/tests/test_pynwb.py b/tests/test_pynwb.py new file mode 100644 index 0000000..c08e3c9 --- /dev/null +++ b/tests/test_pynwb.py @@ -0,0 +1,114 @@ +from typing import Any +import tempfile +import lindi + + +def test_pynwb(): + from datetime import datetime + from uuid import uuid4 + + import numpy as np + from dateutil.tz import tzlocal + + from pynwb import NWBHDF5IO, NWBFile + from pynwb.ecephys import LFP, ElectricalSeries + + nwbfile: Any = NWBFile( + session_description="my first synthetic recording", + identifier=str(uuid4()), + session_start_time=datetime.now(tzlocal()), + experimenter=[ + "Baggins, Bilbo", + ], + lab="Bag End Laboratory", + institution="University of Middle Earth at the Shire", + experiment_description="I went on an adventure to reclaim vast treasures.", + session_id="LONELYMTN001", + ) + + device = nwbfile.create_device( + name="array", description="the best array", manufacturer="Probe Company 9000" + ) + + nwbfile.add_electrode_column(name="label", description="label of electrode") + + nshanks = 4 + nchannels_per_shank = 3 + electrode_counter = 0 + + for ishank in range(nshanks): + # create an electrode group for this shank + electrode_group = nwbfile.create_electrode_group( + name="shank{}".format(ishank), + description="electrode group for shank {}".format(ishank), + device=device, + location="brain area", + ) + # add electrodes to the electrode table + for ielec in range(nchannels_per_shank): + nwbfile.add_electrode( + group=electrode_group, + label="shank{}elec{}".format(ishank, ielec), + location="brain area", + ) + electrode_counter += 1 + + all_table_region = nwbfile.create_electrode_table_region( + region=list(range(electrode_counter)), # reference row indices 0 to N-1 + description="all electrodes", + ) + + raw_data = np.random.randn(50, 12) + raw_electrical_series = ElectricalSeries( + name="ElectricalSeries", + data=raw_data, + electrodes=all_table_region, + starting_time=0.0, # timestamp of the first sample in seconds relative to the session start time + rate=20000.0, # in Hz + ) + + nwbfile.add_acquisition(raw_electrical_series) + + lfp_data = np.random.randn(50, 12) + lfp_electrical_series = ElectricalSeries( + name="ElectricalSeries", + data=lfp_data, + electrodes=all_table_region, + starting_time=0.0, + rate=200.0, + ) + + lfp = LFP(electrical_series=lfp_electrical_series) + + ecephys_module = nwbfile.create_processing_module( + name="ecephys", description="processed extracellular electrophysiology data" + ) + ecephys_module.add(lfp) + + nwbfile.add_unit_column(name="quality", description="sorting quality") + + firing_rate = 20 + n_units = 10 + res = 1000 + duration = 20 + for n_units_per_shank in range(n_units): + spike_times = ( + np.where(np.random.rand((res * duration)) < (firing_rate / res))[0] / res + ) + nwbfile.add_unit(spike_times=spike_times, quality="good") + + with tempfile.TemporaryDirectory() as tmpdir: + nwb_fname = f"{tmpdir}/ecephys_tutorial.nwb" + with NWBHDF5IO(path=nwb_fname, mode="w") as io: + io.write(nwbfile) # type: ignore + # h5f = h5py.File(nwb_fname, "r") + with lindi.LindiH5ZarrStore.from_file(nwb_fname, url=nwb_fname) as store: + rfs = store.to_reference_file_system() + h5f_2 = lindi.LindiH5pyFile.from_reference_file_system(rfs) + with NWBHDF5IO(file=h5f_2, mode="r") as io: + nwbfile_2 = io.read() + print(nwbfile_2) + + +if __name__ == "__main__": + test_pynwb()