diff --git a/hexrd/imageseries/save.py b/hexrd/imageseries/save.py index 59e223346..0de243f51 100644 --- a/hexrd/imageseries/save.py +++ b/hexrd/imageseries/save.py @@ -195,11 +195,16 @@ def __init__(self, ims, fname, **kwargs): cache_file: str or Path, optional name of the npz file to save the image data, if not given in the `fname` argument; for YAML format (deprecated), this is required + max_workers: int, optional + The max number of worker threads for multithreading. Defaults to + the number of CPUs. """ Writer.__init__(self, ims, fname, **kwargs) self._thresh = self._opts['threshold'] self._cache, self.cachename = self._set_cache() - self.max_workers = kwargs.get('max_workers', None) + + ncpus = multiprocessing.cpu_count() + self.max_workers = kwargs.get('max_workers', ncpus) def _set_cache(self): @@ -254,9 +259,7 @@ def _write_frames(self): buff_size = self._ims.shape[0]*self._ims.shape[1] arrd = {} - ncpus = multiprocessing.cpu_count() - max_workers = ncpus if self.max_workers is None else self.max_workers - num_workers = min(max_workers, len(self._ims)) + num_workers = min(self.max_workers, len(self._ims)) row_buffers = np.empty((num_workers, buff_size), dtype=np.uint16) col_buffers = np.empty((num_workers, buff_size), dtype=np.uint16) @@ -274,9 +277,6 @@ def extract_data(i): cols = col_buffers[buffer_id] vals = val_buffers[buffer_id] - # FIXME: in __init__() of ProcessedImageSeries: - # 'ProcessedImageSeries' object has no attribute '_adapter' - # wrapper to find (sparse) pixels above threshold count = extract_ijv(self._ims[i], self._thresh, rows, cols, vals) @@ -302,7 +302,9 @@ def extract_data(i): 'initializer': assign_buffer_id, } with ThreadPoolExecutor(**kwargs) as executor: - executor.map(extract_data, range(len(self._ims))) + # Evaluate the results via `list()`, so that if an exception is + # raised in a thread, it will be re-raised and visible to the user. + list(executor.map(extract_data, range(len(self._ims)))) arrd['shape'] = self._ims.shape arrd['nframes'] = len(self._ims) diff --git a/hexrd/instrument/hedm_instrument.py b/hexrd/instrument/hedm_instrument.py index 8f5b7e1ca..46350f5df 100644 --- a/hexrd/instrument/hedm_instrument.py +++ b/hexrd/instrument/hedm_instrument.py @@ -1257,7 +1257,10 @@ def extract_polar_maps(self, plane_data, imgser_dict, func(task) else: with ThreadPoolExecutor(max_workers=max_workers) as executor: - executor.map(func, tasks) + # Evaluate the results via `list()`, so that if an + # exception is raised in a thread, it will be re-raised + # and visible to the user. + list(executor.map(func, tasks)) ring_maps_panel[det_key] = ring_maps