Skip to content

Commit

Permalink
Merge pull request #728 from ChristosT/add-hdf5-support
Browse files Browse the repository at this point in the history
Add fch5, an HDF5-based format for storing framecaches
  • Loading branch information
ChristosT authored Nov 14, 2024
2 parents 6213ab6 + 301a673 commit 9c595e1
Show file tree
Hide file tree
Showing 5 changed files with 277 additions and 34 deletions.
1 change: 1 addition & 0 deletions conda.recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ requirements:
- lxml >=4.9.2
- fast-histogram
- h5py
- hdf5plugin
- lmfit
- matplotlib-base
- numba
Expand Down
83 changes: 80 additions & 3 deletions hexrd/imageseries/load/framecache.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,17 @@
import numpy as np
from scipy.sparse import csr_matrix
import yaml
import h5py

from . import ImageSeriesAdapter
from ..imageseriesiter import ImageSeriesIterator
from .metadata import yamlmeta
from hexrd.utils.hdf5 import unwrap_h5_to_dict
from hexrd.utils.compatibility import h5py_read_string

import multiprocessing
from concurrent.futures import ThreadPoolExecutor


class FrameCacheImageSeriesAdapter(ImageSeriesAdapter):
"""collection of images in HDF5 format"""
Expand All @@ -26,13 +33,25 @@ def __init__(self, fname, style='npz', **kwargs):
self._framelist = []
self._framelist_was_loaded = False
self._load_framelist_lock = Lock()
# TODO extract style from filename ?
self._style = style.lower()

ncpus = multiprocessing.cpu_count()
self._max_workers = kwargs.get('max_workers', ncpus)

if style.lower() in ('yml', 'yaml', 'test'):
if self._style in ('yml', 'yaml', 'test'):
self._from_yml = True
self._load_yml()
else:
elif self._style == "npz":
self._from_yml = False
self._load_cache()
elif self._style == "fch5":
self._from_yml = False
self._load_cache()
else:
raise TypeError(f"Unknown style format for loading data: {style}."
"Known style formats: 'npz', 'fch5' 'yml', ",
"'yaml', 'test'")

def _load_yml(self):
with open(self._fname, "r") as f:
Expand All @@ -45,6 +64,29 @@ def _load_yml(self):
self._meta = yamlmeta(d['meta'], path=self._cache)

def _load_cache(self):
    """Read the cache header with the loader matching ``self._style``.

    The 'fch5' style is handled by ``_load_cache_fch5``; every other
    style falls through to ``_load_cache_npz``.
    """
    loader = (
        self._load_cache_fch5
        if self._style == 'fch5'
        else self._load_cache_npz
    )
    loader()

def _load_cache_fch5(self):
    """Read the header fields of an fch5 (HDF5) frame cache.

    Opens ``self._fname`` read-only, validates the
    ``HEXRD_FRAMECACHE_VERSION`` file attribute and then fills in
    ``self._shape``, ``self._nframes``, ``self._dtype`` and
    ``self._meta`` from the "shape", "nframes", "dtype" and "metadata"
    entries of the file.

    Raises NotImplementedError when the version attribute is missing or
    is anything other than 1.
    """
    with h5py.File(self._fname, "r") as h5:
        attrs = h5.attrs

        # Reject files that do not declare a framecache version at all.
        if 'HEXRD_FRAMECACHE_VERSION' not in attrs.keys():
            raise NotImplementedError(
                "Unsupported file. "
                "HEXRD_FRAMECACHE_VERSION "
                "is missing!"
            )

        # Only version 1 of the format is understood by this reader.
        version = attrs.get('HEXRD_FRAMECACHE_VERSION', 0)
        if version != 1:
            raise NotImplementedError(
                "Framecache version is not "
                f"supported: {version}"
            )

        self._shape = h5["shape"][()]
        self._nframes = h5["nframes"][()]

        dtype_name = h5py_read_string(h5["dtype"])
        self._dtype = np.dtype(dtype_name)

        # The metadata group is unpacked in place into a fresh dict.
        meta = self._meta = {}
        unwrap_h5_to_dict(h5["metadata"], meta)

def _load_cache_npz(self):
arrs = np.load(self._fname)
# HACK: while the loaded npz file has a getitem method
# that mimics a dict, it doesn't have a "pop" method.
Expand Down Expand Up @@ -79,6 +121,41 @@ def _load_cache(self):

def _load_framelist(self):
    """Fill the frame list using the loader matching ``self._style``.

    The 'fch5' style is handled by ``_load_framelist_fch5``; every other
    style is handled by ``_load_framelist_npz``.
    """
    if self._style != 'fch5':
        self._load_framelist_npz()
        return
    self._load_framelist_fch5()

def _load_framelist_fch5(self):
    """Load every frame of an fch5 file into ``self._framelist``.

    The file stores all frames back to back: for frame ``i`` the entries
    ``frame_ids[2*i]`` and ``frame_ids[2*i + 1]`` give the start and stop
    offsets of that frame's rows in the "data" and "indices" datasets.
    Columns 0 and 1 of "indices" are used as row and column coordinates
    and column 0 of "data" as the values of a ``csr_matrix`` of shape
    ``self._shape`` and dtype ``self._dtype``.

    Frames are read by a thread pool of ``self._max_workers`` workers.
    """
    self._framelist = [None] * self._nframes
    with h5py.File(self._fname, "r") as file:
        # Pull the whole offsets table into memory with one HDF5 read.
        # The workers then index a plain array instead of each issuing
        # four scalar dataset reads per frame.
        frame_id = file["frame_ids"][()]
        data = file["data"]
        indices = file["indices"]

        def read_list_arrays_method_thread(i):
            start, stop = frame_id[2*i], frame_id[2*i + 1]
            frame_data = data[start:stop]
            frame_indices = indices[start:stop]
            row = frame_indices[:, 0]
            col = frame_indices[:, 1]
            mat_data = frame_data[:, 0]
            # Each call writes only its own pre-allocated slot.
            self._framelist[i] = csr_matrix((mat_data, (row, col)),
                                            shape=self._shape,
                                            dtype=self._dtype)

        with ThreadPoolExecutor(max_workers=self._max_workers) as executor:
            # Evaluate the results via `list()`, so that if an exception
            # is raised in a thread, it will be re-raised and visible to
            # the user.
            list(executor.map(read_list_arrays_method_thread,
                              range(self._nframes)))

def _load_framelist_npz(self):
self._framelist = []
if self._from_yml:
bpath = os.path.dirname(self._fname)
Expand Down Expand Up @@ -149,6 +226,6 @@ def __getitem__(self, key):
def __iter__(self):
    """Return an ``ImageSeriesIterator`` built on this adapter."""
    return ImageSeriesIterator(self)

#@memoize
# @memoize
def __len__(self):
    """Return the number of frames recorded in ``self._nframes``."""
    return self._nframes
Loading

0 comments on commit 9c595e1

Please sign in to comment.