diff --git a/conda.recipe/meta.yaml b/conda.recipe/meta.yaml index e6ae499e6..b8c64eedc 100644 --- a/conda.recipe/meta.yaml +++ b/conda.recipe/meta.yaml @@ -42,6 +42,7 @@ requirements: - lxml >=4.9.2 - fast-histogram - h5py + - hdf5plugin - lmfit - matplotlib-base - numba diff --git a/hexrd/imageseries/load/framecache.py b/hexrd/imageseries/load/framecache.py index 93b3a8ad5..a42fc4d11 100644 --- a/hexrd/imageseries/load/framecache.py +++ b/hexrd/imageseries/load/framecache.py @@ -6,10 +6,17 @@ import numpy as np from scipy.sparse import csr_matrix import yaml +import h5py from . import ImageSeriesAdapter from ..imageseriesiter import ImageSeriesIterator from .metadata import yamlmeta +from hexrd.utils.hdf5 import unwrap_h5_to_dict +from hexrd.utils.compatibility import h5py_read_string + +import multiprocessing +from concurrent.futures import ThreadPoolExecutor + class FrameCacheImageSeriesAdapter(ImageSeriesAdapter): """collection of images in HDF5 format""" @@ -26,13 +33,25 @@ def __init__(self, fname, style='npz', **kwargs): self._framelist = [] self._framelist_was_loaded = False self._load_framelist_lock = Lock() + # TODO extract style from filename ? + self._style = style.lower() + + ncpus = multiprocessing.cpu_count() + self._max_workers = kwargs.get('max_workers', ncpus) - if style.lower() in ('yml', 'yaml', 'test'): + if self._style in ('yml', 'yaml', 'test'): self._from_yml = True self._load_yml() - else: + elif self._style == "npz": self._from_yml = False self._load_cache() + elif self._style == "fch5": + self._from_yml = False + self._load_cache() + else: + raise TypeError(f"Unknown style format for loading data: {style}." + "Known style formats: 'npz', 'fch5' 'yml', ", + "'yaml', 'test'") def _load_yml(self): with open(self._fname, "r") as f: @@ -45,6 +64,29 @@ def _load_yml(self): self._meta = yamlmeta(d['meta'], path=self._cache) def _load_cache(self): + if self._style == 'fch5': + self._load_cache_fch5() + else: + self._load_cache_npz() + + def _load_cache_fch5(self): + with h5py.File(self._fname, "r") as file: + if 'HEXRD_FRAMECACHE_VERSION' not in file.attrs.keys(): + raise NotImplementedError("Unsupported file. " + "HEXRD_FRAMECACHE_VERSION " + "is missing!") + version = file.attrs.get('HEXRD_FRAMECACHE_VERSION', 0) + if version != 1: + raise NotImplementedError("Framecache version is not " + f"supported: {version}") + + self._shape = file["shape"][()] + self._nframes = file["nframes"][()] + self._dtype = np.dtype(h5py_read_string(file["dtype"])) + self._meta = {} + unwrap_h5_to_dict(file["metadata"], self._meta) + + def _load_cache_npz(self): arrs = np.load(self._fname) # HACK: while the loaded npz file has a getitem method # that mimicks a dict, it doesn't have a "pop" method. @@ -79,6 +121,41 @@ def _load_cache(self): def _load_framelist(self): """load into list of csr sparse matrices""" + if self._style == 'fch5': + self._load_framelist_fch5() + else: + self._load_framelist_npz() + + def _load_framelist_fch5(self): + self._framelist = [None] * self._nframes + with h5py.File(self._fname, "r") as file: + frame_id = file["frame_ids"] + data = file["data"] + indices = file["indices"] + + def read_list_arrays_method_thread(i): + frame_data = data[frame_id[2*i]: frame_id[2*i+1]] + frame_indices = indices[frame_id[2*i]: frame_id[2*i+1]] + row = frame_indices[:, 0] + col = frame_indices[:, 1] + mat_data = frame_data[:, 0] + frame = csr_matrix((mat_data, (row, col)), + shape=self._shape, + dtype=self._dtype) + self._framelist[i] = frame + return + + kwargs = { + "max_workers": self._max_workers, + } + with ThreadPoolExecutor(**kwargs) as executor: + # Evaluate the results via `list()`, so that if an exception is + # raised in a thread, it will be re-raised and visible to the + # user. + list(executor.map(read_list_arrays_method_thread, + range(self._nframes))) + + def _load_framelist_npz(self): self._framelist = [] if self._from_yml: bpath = os.path.dirname(self._fname) @@ -149,6 +226,6 @@ def __getitem__(self, key): def __iter__(self): return ImageSeriesIterator(self) - #@memoize + # @memoize def __len__(self): return self._nframes diff --git a/hexrd/imageseries/save.py b/hexrd/imageseries/save.py index 092d30aa6..ea669c9a0 100644 --- a/hexrd/imageseries/save.py +++ b/hexrd/imageseries/save.py @@ -9,11 +9,13 @@ import numpy as np import h5py +import hdf5plugin import yaml from hexrd.matrixutil import extract_ijv +from hexrd.utils.hdf5 import unwrap_dict_to_h5 -MAX_NZ_FRACTION = 0.1 # 10% sparsity trigger for frame-cache write +MAX_NZ_FRACTION = 0.1 # 10% sparsity trigger for frame-cache write # ============================================================================= @@ -42,7 +44,6 @@ def write(ims, fname, fmt, **kwargs): # Registry class _RegisterWriter(abc.ABCMeta): - def __init__(cls, name, bases, attrs): abc.ABCMeta.__init__(cls, name, bases, attrs) _Registry.register(cls) @@ -50,6 +51,7 @@ def __init__(cls, name, bases, attrs): class _Registry(object): """Registry for imageseries writers""" + writer_registry = dict() @classmethod @@ -76,6 +78,7 @@ class Writer(object, metaclass=_RegisterWriter): kwargs: dict options specific to format """ + fmt = None def __init__(self, ims, fname, **kwargs): @@ -111,6 +114,7 @@ def fname_dir(self): def opts(self): return self._opts + class WriteH5(Writer): """Write imageseries in HDF5 file @@ -129,6 +133,7 @@ class WriteH5(Writer): shuffle: bool shuffle HDF5 data """ + fmt = 'hdf5' dflt_gzip = 1 dflt_chrows = 0 @@ -151,8 +156,9 @@ def write(self): g = f.create_group(self._path) s0, s1 = self._shape - ds = g.create_dataset('images', (self._nframes, s0, s1), self._dtype, - **self.h5opts) + ds = g.create_dataset( + 'images', (self._nframes, s0, s1), self._dtype, **self.h5opts + ) for i in range(self._nframes): ds[i, :, :] = self._ims[i] @@ -211,22 +217,34 @@ class WriteFrameCache(Writer): cache_file: str or Path, optional name of the npz file to save the image data, if not given in the `fname` argument; for YAML format (deprecated), this is required + style: str, type of file to use for saving. options are: + - 'npz' for saving in a numpy compressed file + - 'fch5' for saving in the HDF5-based frame-cache format max_workers: int, optional The max number of worker threads for multithreading. Defaults to the number of CPUs. """ + fmt = 'frame-cache' - def __init__(self, ims, fname, **kwargs): + def __init__(self, ims, fname, style='npz', **kwargs): Writer.__init__(self, ims, fname, **kwargs) self._thresh = self._opts['threshold'] self._cache, self.cachename = self._set_cache() ncpus = multiprocessing.cpu_count() self.max_workers = kwargs.get('max_workers', ncpus) + supported_formats = ['npz', 'fch5'] + if style not in supported_formats: + raise TypeError( + f"Unknown file style for writing framecache: {style}. " + f"Supported formats are {supported_formats}" + ) + self.style = style - def _set_cache(self): + self.hdf5_compression = hdf5plugin.Blosc(cname="zstd", clevel=5) + def _set_cache(self): cf = self.opts.get('cache_file') if cf is None: @@ -267,15 +285,40 @@ def _process_meta(self, save_omegas=False): return d def _write_yml(self): - datad = {'file': self._cachename, 'dtype': str(self._ims.dtype), - 'nframes': len(self._ims), 'shape': list(self._ims.shape)} + datad = { + 'file': self._cachename, + 'dtype': str(self._ims.dtype), + 'nframes': len(self._ims), + 'shape': list(self._ims.shape), + } info = {'data': datad, 'meta': self._process_meta(save_omegas=True)} with open(self._fname, "w") as f: yaml.safe_dump(info, f) def _write_frames(self): + if self.style == 'npz': + self._write_frames_npz() + elif self.style == 'fch5': + self._write_frames_fch5() + + def _check_sparsity(self, frame_id, count, buff_size): + # check the sparsity + # + # FIXME: formalize this a little better + # ???: maybe set a hard limit of total nonzeros for the imageseries + # ???: could pass as a kwarg on open + fullness = count / float(buff_size) + if fullness > MAX_NZ_FRACTION: + sparseness = 100.0 * (1 - fullness) + msg = "frame %d is %4.2f%% sparse (cutoff is 95%%)" % ( + frame_id, + sparseness, + ) + warnings.warn(msg) + + def _write_frames_npz(self): """also save shape array as originally done (before yaml)""" - buff_size = self._ims.shape[0]*self._ims.shape[1] + buff_size = self._ims.shape[0] * self._ims.shape[1] arrd = {} num_workers = min(self.max_workers, len(self._ims)) @@ -297,20 +340,9 @@ def extract_data(i): vals = val_buffers[buffer_id] # wrapper to find (sparse) pixels above threshold - count = extract_ijv(self._ims[i], self._thresh, - rows, cols, vals) - - # check the sparsity - # - # FIXME: formalize this a little better - # ???: maybe set a hard limit of total nonzeros for the imageseries - # ???: could pass as a kwarg on open - fullness = count / float(buff_size) - if fullness > MAX_NZ_FRACTION: - sparseness = 100.*(1 - fullness) - msg = "frame %d is %4.2f%% sparse (cutoff is 95%%)" \ - % (i, sparseness) - warnings.warn(msg) + count = extract_ijv(self._ims[i], self._thresh, rows, cols, vals) + + self._check_sparsity(i, count, buff_size) arrd[f'{i}_row'] = rows[:count].copy() arrd[f'{i}_col'] = cols[:count].copy() @@ -331,6 +363,117 @@ def extract_data(i): arrd.update(self._process_meta()) np.savez_compressed(self.cache, **arrd) + def _write_frames_fch5(self): + """Write framecache into an hdf5 file. The file will use three + datasets for the framecache: + - 'data': (m,1) array holding the datavalues of all frames. `m` is + evaluated upon runtime + - 'indices': (m,2) array holding the row& col information for the + values in data. 'data' together within 'indices' represent tha data + using the CSR format for sparse matrices. + - 'frame_ids`: (2*nframes) holds the range that the i-th frame + occupies in the above arrays. i.e. the information of the i-th frame + can be accessed using: + + data_i = data[frame_ids[2*i]:frame_ids[2*i+1]] and + indices_i = indices[frame_ids[2*i]:frame_ids[2*i+1]] + """ + max_frame_size = self._ims.shape[0] * self._ims.shape[1] + nframes = len(self._ims) + shape = self._ims.shape + data_dtype = self._ims.dtype + + frame_indices = np.empty((2 * nframes,), dtype=np.uint64) + data_dataset = None + indices_dataset = None + file_position = 0 + total_size = 0 + + common_lock = threading.Lock() + thread_local = threading.local() + + # creating an array in memory will fail if data is too big or threshold + # too low, so we write to the file while iterating the frames + with h5py.File(self.cache, "w") as h5f: + h5f.attrs['HEXRD_FRAMECACHE_VERSION'] = 1 + h5f["shape"] = shape + h5f["nframes"] = nframes + h5f["dtype"] = str(np.dtype(self._ims.dtype)).encode("utf-8") + metadata = h5f.create_group("metadata") + unwrap_dict_to_h5(metadata, self._meta.copy()) + + def initialize_buffers(): + thread_local.data = np.empty( + (max_frame_size, 1), dtype=self._ims.dtype + ) + thread_local.indices = np.empty( + (max_frame_size, 2), dtype=np.uint16 + ) + + def single_array_write_thread(i): + nonlocal file_position, total_size + im = self._ims[i] + row_slice = thread_local.indices[:, 0] + col_slice = thread_local.indices[:, 1] + data_slice = thread_local.data[:, 0] + count = extract_ijv( + im, self._thresh, row_slice, col_slice, data_slice + ) + + self._check_sparsity(i, count, max_frame_size) + + # get the range this thread is doing to write into the file + start_file = 0 + end_file = 0 + with common_lock: + start_file = file_position + file_position += count + end_file = file_position + total_size += end_file - start_file + # write within the appropriate ranges + data_dataset[start_file:end_file, :] = thread_local.data[ + :count, : + ] + indices_dataset[start_file:end_file, :] = thread_local.indices[ + :count, : + ] + frame_indices[2 * i] = start_file + frame_indices[2 * i + 1] = end_file + + kwargs = { + "max_workers": self.max_workers, + "initializer": initialize_buffers, + } + + data_dataset = h5f.create_dataset( + "data", + shape=(nframes * max_frame_size, 1), + dtype=data_dtype, + compression=self.hdf5_compression, + ) + indices_dataset = h5f.create_dataset( + "indices", + shape=(nframes * max_frame_size, 2), + dtype=np.uint16, + compression=self.hdf5_compression, + ) + with ThreadPoolExecutor(**kwargs) as executor: + # Evaluate the results via `list()`, so that if an exception is + # raised in a thread, it will be re-raised and visible to + # the user. + list(executor.map(single_array_write_thread, range(nframes))) + + # update the sizes of the dataset to match the amount of data + # that have been actually written + data_dataset.resize(total_size, axis=0) + indices_dataset.resize(total_size, axis=0) + + h5f.create_dataset( + "frame_ids", + data=frame_indices, + compression=self.hdf5_compression, + ) + def write(self, output_yaml=False): """writes frame cache for imageseries @@ -339,7 +482,6 @@ def write(self, output_yaml=False): self._write_frames() if output_yaml: warnings.warn( - "YAML output for frame-cache is deprecated", - DeprecationWarning + "YAML output for frame-cache is deprecated", DeprecationWarning ) self._write_yml() diff --git a/setup.py b/setup.py index db7c45a6e..7e0e0ef59 100644 --- a/setup.py +++ b/setup.py @@ -17,6 +17,7 @@ 'fast-histogram', 'h5py<3.12', # Currently, h5py 3.12 on Windows fails to import. # We can remove this version pin when that is fixed. + 'hdf5plugin', 'lmfit', 'matplotlib', 'numba', diff --git a/tests/imageseries/test_formats.py b/tests/imageseries/test_formats.py index ae33bb8f8..6d03c6453 100644 --- a/tests/imageseries/test_formats.py +++ b/tests/imageseries/test_formats.py @@ -85,6 +85,7 @@ def setUp(self): self.fcfile = os.path.join(self.tmpdir, 'frame-cache.npz') self.fmt = 'frame-cache' self.thresh = 0.5 + self.style = 'npz' self.cache_file='frame-cache.npz' _, self.is_a = make_array_ims() @@ -93,9 +94,9 @@ def tearDown(self): def test_fmtfc(self): """save/load frame-cache format""" - imageseries.write(self.is_a, self.fcfile, self.fmt, + imageseries.write(self.is_a, self.fcfile, self.fmt, style=self.style, threshold=self.thresh, cache_file=self.cache_file) - is_fc = imageseries.open(self.fcfile, self.fmt) + is_fc = imageseries.open(self.fcfile, self.fmt, style=self.style) diff = compare(self.is_a, is_fc) self.assertAlmostEqual(diff, 0., "frame-cache reconstruction failed") self.assertTrue(compare_meta(self.is_a, is_fc)) @@ -104,9 +105,9 @@ def test_fmtfc_nocache_file(self): """save/load frame-cache format with no cache_file arg""" imageseries.write( self.is_a, self.fcfile, self.fmt, - threshold=self.thresh + threshold=self.thresh, style=self.style ) - is_fc = imageseries.open(self.fcfile, self.fmt) + is_fc = imageseries.open(self.fcfile, self.fmt, style=self.style) diff = compare(self.is_a, is_fc) self.assertAlmostEqual(diff, 0., "frame-cache reconstruction failed") self.assertTrue(compare_meta(self.is_a, is_fc)) @@ -117,11 +118,32 @@ def test_fmtfc_nparray(self): npa = np.array([0,2.0,1.3]) self.is_a.metadata[key] = npa - imageseries.write(self.is_a, self.fcfile, self.fmt, + imageseries.write(self.is_a, self.fcfile, self.fmt, style=self.style, threshold=self.thresh, cache_file=self.cache_file ) - is_fc = imageseries.open(self.fcfile, self.fmt) + is_fc = imageseries.open(self.fcfile, self.fmt, style=self.style) meta = is_fc.metadata diff = np.linalg.norm(meta[key] - npa) self.assertAlmostEqual(diff, 0., "frame-cache numpy array metadata failed") + + +class TestFormatFrameCache_FCH5(TestFormatFrameCache): + + def setUp(self): + self.fcfile = os.path.join(self.tmpdir, 'frame-cache.fch5') + self.fmt = 'frame-cache' + self.style = 'fch5' + self.thresh = 0.5 + self.cache_file = 'frame-cache.fch5' + _, self.is_a = make_array_ims() + + def test_fmtfc_nested_metadata(self): + """frame-cache format with nested metadata""" + metadata = {'int': 1, 'array': np.array([1, 2, 3])} + self.is_a.metadata["key"] = metadata + + imageseries.write(self.is_a, self.fcfile, self.fmt, style=self.style, + threshold=self.thresh, cache_file=self.cache_file) + is_fc = imageseries.open(self.fcfile, self.fmt, style=self.style) + self.assertTrue(compare_meta(self.is_a, is_fc))