From 86ad79866591cd67877b302448c83457179d6a65 Mon Sep 17 00:00:00 2001 From: Patrick Avery Date: Thu, 5 Sep 2024 04:44:29 -0500 Subject: [PATCH] Make framelist loading thread-safe This is so that `__getitem__` is thread-safe again. Only one thread should load the framelist, and the other threads wait for it to finish. Signed-off-by: Patrick Avery --- hexrd/imageseries/load/framecache.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/hexrd/imageseries/load/framecache.py b/hexrd/imageseries/load/framecache.py index 9dee77a3..93b3a8ad 100644 --- a/hexrd/imageseries/load/framecache.py +++ b/hexrd/imageseries/load/framecache.py @@ -1,6 +1,7 @@ """Adapter class for frame caches """ import os +from threading import Lock import numpy as np from scipy.sparse import csr_matrix @@ -23,6 +24,9 @@ def __init__(self, fname, style='npz', **kwargs): """ self._fname = fname self._framelist = [] + self._framelist_was_loaded = False + self._load_framelist_lock = Lock() + if style.lower() in ('yml', 'yaml', 'test'): self._from_yml = True self._load_yml() @@ -95,11 +99,6 @@ def _load_framelist(self): dtype=self._dtype) self._framelist.append(frame) - @property - def _framelist_was_loaded(self): - # Just assume that if the framelist is empty, it wasn't loaded... - return len(self._framelist) > 0 - @property def metadata(self): """(read-only) Image sequence metadata @@ -131,10 +130,20 @@ def dtype(self): def shape(self): return self._shape - def __getitem__(self, key): + def _load_framelist_if_needed(self): if not self._framelist_was_loaded: - # Load the framelist now - self._load_framelist() + # Only one thread should load the framelist. + # Acquire the lock for loading the framelist. + with self._load_framelist_lock: + # It is possible that another thread already loaded + # the framelist by the time this lock was acquired. + # Check again. + if not self._framelist_was_loaded: + self._load_framelist() + self._framelist_was_loaded = True + + def __getitem__(self, key): + self._load_framelist_if_needed() return self._framelist[key].toarray() def __iter__(self):