Skip to content

Commit

Permalink
Merge pull request #611 from HEXRD/raw-frames-threadsafe
Browse files Browse the repository at this point in the history
Make __getitem__ for raw imageseries threadsafe
  • Loading branch information
donald-e-boyce authored Jan 25, 2024
2 parents d60f265 + f4a6789 commit 81bc561
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 11 deletions.
10 changes: 8 additions & 2 deletions hexrd/imageseries/load/rawimage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
""" Adapter class for raw image reader"""
import os
import threading

import numpy as np
import yaml
Expand Down Expand Up @@ -30,6 +31,7 @@ def __init__(self, fname, **kwargs):
self._shape = tuple((int(si) for si in y['shape'].split()))
self._frame_size = self._shape[0] * self._shape[1]
self._frame_bytes = self._frame_size * self.dtype.itemsize
self._frame_read_lock = threading.Lock()
self.skipbytes = y['skip']
self._len = self._get_length()
self._meta = dict()
Expand Down Expand Up @@ -105,8 +107,12 @@ def __iter__(self):

def __getitem__(self, key):
count = key * self._frame_bytes + self.skipbytes
self.f.seek(count, 0)
frame = np.fromfile(self.f, self.dtype, count=self._frame_size)

# Ensure reading a frame the file is thread-safe
with self._frame_read_lock:
self.f.seek(count, 0)
frame = np.fromfile(self.f, self.dtype, count=self._frame_size)

return frame.reshape(self.shape)

@property
Expand Down
18 changes: 10 additions & 8 deletions hexrd/imageseries/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,16 @@ def __init__(self, ims, fname, **kwargs):
cache_file: str or Path, optional
name of the npz file to save the image data, if not given in the
`fname` argument; for YAML format (deprecated), this is required
max_workers: int, optional
The max number of worker threads for multithreading. Defaults to
the number of CPUs.
"""
Writer.__init__(self, ims, fname, **kwargs)
self._thresh = self._opts['threshold']
self._cache, self.cachename = self._set_cache()
self.max_workers = kwargs.get('max_workers', None)

ncpus = multiprocessing.cpu_count()
self.max_workers = kwargs.get('max_workers', ncpus)

def _set_cache(self):

Expand Down Expand Up @@ -254,9 +259,7 @@ def _write_frames(self):
buff_size = self._ims.shape[0]*self._ims.shape[1]
arrd = {}

ncpus = multiprocessing.cpu_count()
max_workers = ncpus if self.max_workers is None else self.max_workers
num_workers = min(max_workers, len(self._ims))
num_workers = min(self.max_workers, len(self._ims))

row_buffers = np.empty((num_workers, buff_size), dtype=np.uint16)
col_buffers = np.empty((num_workers, buff_size), dtype=np.uint16)
Expand All @@ -274,9 +277,6 @@ def extract_data(i):
cols = col_buffers[buffer_id]
vals = val_buffers[buffer_id]

# FIXME: in __init__() of ProcessedImageSeries:
# 'ProcessedImageSeries' object has no attribute '_adapter'

# wrapper to find (sparse) pixels above threshold
count = extract_ijv(self._ims[i], self._thresh,
rows, cols, vals)
Expand All @@ -302,7 +302,9 @@ def extract_data(i):
'initializer': assign_buffer_id,
}
with ThreadPoolExecutor(**kwargs) as executor:
executor.map(extract_data, range(len(self._ims)))
# Evaluate the results via `list()`, so that if an exception is
# raised in a thread, it will be re-raised and visible to the user.
list(executor.map(extract_data, range(len(self._ims))))

arrd['shape'] = self._ims.shape
arrd['nframes'] = len(self._ims)
Expand Down
5 changes: 4 additions & 1 deletion hexrd/instrument/hedm_instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -1257,7 +1257,10 @@ def extract_polar_maps(self, plane_data, imgser_dict,
func(task)
else:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
executor.map(func, tasks)
# Evaluate the results via `list()`, so that if an
# exception is raised in a thread, it will be re-raised
# and visible to the user.
list(executor.map(func, tasks))

ring_maps_panel[det_key] = ring_maps

Expand Down

0 comments on commit 81bc561

Please sign in to comment.