Skip to content

Commit

Permalink
Ensure errors during multithreading get propagated
Browse files Browse the repository at this point in the history
Since we were not consuming the iterator returned by `executor.map()`, any
exception raised in a worker thread was silently discarded instead of being
re-raised, leaving users with no indication that anything had failed.

Propagating the errors ensures that users can see (and report) when
they encounter an error.

Signed-off-by: Patrick Avery <[email protected]>
  • Loading branch information
psavery committed Jan 25, 2024
1 parent cc00d7d commit f4a6789
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
18 changes: 10 additions & 8 deletions hexrd/imageseries/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,16 @@ def __init__(self, ims, fname, **kwargs):
cache_file: str or Path, optional
name of the npz file to save the image data, if not given in the
`fname` argument; for YAML format (deprecated), this is required
max_workers: int, optional
The max number of worker threads for multithreading. Defaults to
the number of CPUs.
"""
Writer.__init__(self, ims, fname, **kwargs)
self._thresh = self._opts['threshold']
self._cache, self.cachename = self._set_cache()
self.max_workers = kwargs.get('max_workers', None)

ncpus = multiprocessing.cpu_count()
self.max_workers = kwargs.get('max_workers', ncpus)

def _set_cache(self):

Expand Down Expand Up @@ -254,9 +259,7 @@ def _write_frames(self):
buff_size = self._ims.shape[0]*self._ims.shape[1]
arrd = {}

ncpus = multiprocessing.cpu_count()
max_workers = ncpus if self.max_workers is None else self.max_workers
num_workers = min(max_workers, len(self._ims))
num_workers = min(self.max_workers, len(self._ims))

row_buffers = np.empty((num_workers, buff_size), dtype=np.uint16)
col_buffers = np.empty((num_workers, buff_size), dtype=np.uint16)
Expand All @@ -274,9 +277,6 @@ def extract_data(i):
cols = col_buffers[buffer_id]
vals = val_buffers[buffer_id]

# FIXME: in __init__() of ProcessedImageSeries:
# 'ProcessedImageSeries' object has no attribute '_adapter'

# wrapper to find (sparse) pixels above threshold
count = extract_ijv(self._ims[i], self._thresh,
rows, cols, vals)
Expand All @@ -302,7 +302,9 @@ def extract_data(i):
'initializer': assign_buffer_id,
}
with ThreadPoolExecutor(**kwargs) as executor:
executor.map(extract_data, range(len(self._ims)))
# Evaluate the results via `list()`, so that if an exception is
# raised in a thread, it will be re-raised and visible to the user.
list(executor.map(extract_data, range(len(self._ims))))

arrd['shape'] = self._ims.shape
arrd['nframes'] = len(self._ims)
Expand Down
5 changes: 4 additions & 1 deletion hexrd/instrument/hedm_instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -1257,7 +1257,10 @@ def extract_polar_maps(self, plane_data, imgser_dict,
func(task)
else:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
executor.map(func, tasks)
# Evaluate the results via `list()`, so that if an
# exception is raised in a thread, it will be re-raised
# and visible to the user.
list(executor.map(func, tasks))

ring_maps_panel[det_key] = ring_maps

Expand Down

0 comments on commit f4a6789

Please sign in to comment.