Skip to content

Commit

Permalink
fix refs
Browse files Browse the repository at this point in the history
  • Loading branch information
magland committed Mar 19, 2024
1 parent cffc1e3 commit f74ef59
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 34 deletions.
44 changes: 19 additions & 25 deletions devel/test_load_pynwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,44 +5,33 @@
import h5py
import lindi
import json
import remfile


def test_load_pynwb():
# https://neurosift.app/?p=/nwb&dandisetId=000939&dandisetVersion=0.240318.1555&url=https://api.dandiarchive.org/api/assets/11f512ba-5bcf-4230-a8cb-dc8d36db38cb/download/
# url_nwb = "https://api.dandiarchive.org/api/assets/11f512ba-5bcf-4230-a8cb-dc8d36db38cb/download/"
url_nwb = "https://api.dandiarchive.org/api/assets/11f512ba-5bcf-4230-a8cb-dc8d36db38cb/download/"
url = "https://kerchunk.neurosift.org/dandi/dandisets/000939/assets/11f512ba-5bcf-4230-a8cb-dc8d36db38cb/zarr.json"

thisdir = os.path.dirname(os.path.abspath(__file__))
fname = thisdir + "/test.zarr.json"
if not os.path.exists(fname):
_download_file(url, fname)

# remf = remfile.File(url_nwb)
# h5f0 = h5py.File(remf, mode="r")
h5f0 = h5py.File("/home/magland/test.nwb", mode="r")
remf = remfile.File(url_nwb)
h5f0 = h5py.File(remf, mode="r")
h5f = lindi.LindiH5pyFile.from_h5py_file(h5f0)
with pynwb.NWBHDF5IO(file=h5f, mode="r") as io:
nwb = io.read()
print(nwb)
for k in nwb.fields:
print(
f"________________________________ {k} __________________________________"
)
print(getattr(nwb, k))

print("-------------------------------------------")
store = lindi.LindiH5ZarrStore.from_file(
"/home/magland/test.nwb", url="/home/magland/test.nwb"
)
store = lindi.LindiH5ZarrStore.from_file(url_nwb, url=url_nwb)
rfs = store.to_reference_file_system()
with open("test_rfs.zarr.json", "w") as f:
json.dump(rfs, f, indent=2)
hf5_rfs = lindi.LindiH5pyFile.from_reference_file_system(rfs)

_compare_h5py_files(h5f0, hf5_rfs)
_compare_h5py_files(h5f, hf5_rfs)

with pynwb.NWBHDF5IO(file=hf5_rfs, mode="r") as io:
nwb = io.read()
with pynwb.NWBHDF5IO(file=hf5_rfs, mode="r") as io1:
nwb = io1.read()
print(nwb)
for k in nwb.fields:
print(
Expand Down Expand Up @@ -115,25 +104,25 @@ def _compare_h5py_groups(g1: h5py.Group, g2: h5py.Group, label: str):
if isinstance(obj1, h5py.Group):
obj1x = g1.get(k, getlink=True)
obj2x = g2.get(k, getlink=True)
if isinstance(obj1x, h5py.SoftLink):
if isinstance(obj2x, lindi.LindiH5pySoftLink):
if isinstance(obj1x, h5py.SoftLink) or isinstance(obj1x, lindi.LindiH5pySoftLink):
if isinstance(obj2x, h5py.SoftLink) or isinstance(obj2x, lindi.LindiH5pySoftLink):
pass
else:
print(f"*************** Link type mismatch for {k}")
print(type(obj1x))
print(type(obj2x))
elif isinstance(obj1x, h5py.HardLink):
if isinstance(obj2x, lindi.LindiH5pyHardLink):
elif isinstance(obj1x, h5py.HardLink) or isinstance(obj1x, lindi.LindiH5pyHardLink):
if isinstance(obj2x, h5py.HardLink) or isinstance(obj2x, lindi.LindiH5pyHardLink):
pass
else:
print(f"*************** Hard link type mismatch for {k}")
print(type(obj1x))
print(type(obj2x))
elif isinstance(obj2x, lindi.LindiH5pySoftLink):
elif isinstance(obj2x, h5py.SoftLink) or isinstance(obj2x, lindi.LindiH5pySoftLink):
print(f"*************** Link type mismatch for {k}")
print(type(obj1x))
print(type(obj2x))
elif isinstance(obj2x, lindi.LindiH5pyHardLink):
elif isinstance(obj2x, h5py.HardLink) or isinstance(obj2x, lindi.LindiH5pyHardLink):
print(f"*************** Link type mismatch for {k}")
print(type(obj1x))
print(type(obj2x))
Expand All @@ -159,6 +148,11 @@ def _compare_h5py_datasets(d1: h5py.Dataset, d2: h5py.Dataset, label: str):
print("*************** Ndim mismatch")
if d1.maxshape != d2.maxshape:
print("*************** Maxshape mismatch")
if d1.size and d1.size < 100:
if not _check_equal(d1[()], d2[()]):
print("*************** Data mismatch")
print(f" h5f1: {d1[()].ravel()[:5]}")
print(f" h5f2: {d2[()].ravel()[:5]}")


def _download_file(url, fname):
Expand Down
2 changes: 1 addition & 1 deletion lindi/LindiH5ZarrStore/LindiH5ZarrStore.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def listdir(self, path: str = "") -> List[str]:
if self._h5f is None:
raise Exception("Store is closed")
try:
item = self._h5f[path]
item = self._h5f['/' + path]
except KeyError:
return []
if isinstance(item, h5py.Group):
Expand Down
10 changes: 9 additions & 1 deletion lindi/LindiH5pyFile/LindiH5pyDataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import TYPE_CHECKING, Union
import h5py

from .LindiH5pyAttributes import LindiH5pyAttributes
from .LindiH5pyReference import LindiH5pyReference
from ..LindiZarrWrapper import LindiZarrWrapperDataset
from ..LindiZarrWrapper import LindiZarrWrapperReference


if TYPE_CHECKING:
Expand Down Expand Up @@ -61,4 +64,9 @@ def attrs(self): # type: ignore
return LindiH5pyAttributes(self._dataset_object.attrs)

def __getitem__(self, args, new_dtype=None):
return self._dataset_object.__getitem__(args, new_dtype)
ret = self._dataset_object.__getitem__(args, new_dtype)
if isinstance(self._dataset_object, LindiZarrWrapperDataset):
if isinstance(ret, dict):
if '_REFERENCE' in ret:
ret = LindiH5pyReference(LindiZarrWrapperReference(ret['_REFERENCE']))
return ret
23 changes: 22 additions & 1 deletion lindi/LindiH5pyFile/LindiH5pyFile.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from typing import Union
import h5py
import zarr

from .LindiH5pyGroup import LindiH5pyGroup
from ..LindiZarrWrapper import LindiZarrWrapper
from .LindiH5pyDataset import LindiH5pyDataset
from ..LindiZarrWrapper import LindiZarrWrapper, LindiZarrWrapperGroup, LindiZarrWrapperDataset
from .LindiH5pyAttributes import LindiH5pyAttributes
from .LindiH5pyReference import LindiH5pyReference


class LindiH5pyFile(h5py.File):
Expand Down Expand Up @@ -96,6 +99,24 @@ def __repr__(self):
# Group methods

def __getitem__(self, name):
if isinstance(name, LindiH5pyReference):
assert isinstance(self._file_object, LindiZarrWrapper)
x = self._file_object[name._reference]
if isinstance(x, LindiZarrWrapperGroup):
return LindiH5pyGroup(x, self)
elif isinstance(x, LindiZarrWrapperDataset):
return LindiH5pyDataset(x, self)
else:
raise Exception(f"Unexpected type for resolved reference at path {name}: {type(x)}")
elif isinstance(name, h5py.Reference):
assert isinstance(self._file_object, h5py.File)
x = self._file_object[name]
if isinstance(x, h5py.Group):
return LindiH5pyGroup(x, self)
elif isinstance(x, h5py.Dataset):
return LindiH5pyDataset(x, self)
else:
raise Exception(f"Unexpected type for resolved reference at path {name}: {type(x)}")
return self._the_group[name]

def get(self, name, default=None, getclass=False, getlink=False):
Expand Down
8 changes: 2 additions & 6 deletions lindi/LindiH5pyFile/LindiH5pyGroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from .LindiH5pyLink import LindiH5pyHardLink, LindiH5pySoftLink
from ..LindiZarrWrapper import LindiZarrWrapperGroup
from .LindiH5pyAttributes import LindiH5pyAttributes
from .LindiH5pyReference import LindiH5pyReference


if TYPE_CHECKING:
Expand All @@ -25,7 +24,7 @@ def __init__(self, _group_object: Union[h5py.Group, LindiZarrWrapperGroup], _fil

def __getitem__(self, name):
if isinstance(self._group_object, h5py.Group):
if isinstance(name, h5py.h5r.Reference) or isinstance(name, (bytes, str)):
if isinstance(name, (bytes, str)):
x = self._group_object[name]
else:
raise TypeError(
Expand All @@ -39,10 +38,7 @@ def __getitem__(self, name):
else:
raise Exception(f"Unknown type: {type(x)}")
elif isinstance(self._group_object, LindiZarrWrapperGroup):
if isinstance(name, LindiH5pyReference):
# is this the right thing to do?
x = self._group_object.file[name._reference]
elif isinstance(name, (bytes, str)):
if isinstance(name, (bytes, str)):
x = self._group_object[name]
else:
raise TypeError(
Expand Down
4 changes: 4 additions & 0 deletions tests/test_with_real_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,7 @@ def test_with_real_data():

root = zarr.open(store, mode="r")
_hdf5_visit_items(h5f, lambda key, item: _compare_item_2(item, root[key]))


if __name__ == "__main__":
test_with_real_data()

0 comments on commit f74ef59

Please sign in to comment.