From 781d18232c3807e974dbb68bdf6eed26e59d0ed8 Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Fri, 25 Oct 2024 10:49:11 +0100 Subject: [PATCH] Added lazy loading functionality --- stride/__init__.py | 9 ++ stride/optimisation/optimisation_loop.py | 14 +-- stride/problem/acquisitions.py | 90 ++++++++++++++++--- stride/problem/base.py | 6 +- stride/problem/data.py | 12 +-- stride/problem/geometry.py | 6 +- stride/problem/transducer_types/transducer.py | 2 +- stride/problem/transducers.py | 4 +- 8 files changed, 110 insertions(+), 33 deletions(-) diff --git a/stride/__init__.py b/stride/__init__.py index 8666b698..0d020bf5 100644 --- a/stride/__init__.py +++ b/stride/__init__.py @@ -187,6 +187,8 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa select_shots : dict, optional Rules for selecting available shots per iteration, defaults to taking all shots. For details on this see :func:`~stride.problem.acquisitions.Acquisitions.select_shot_ids`. + lazy_loading : bool, optional + Whether to load shot data every iteration to save memory. dump : bool, optional Whether or not to save to disk the updated variable after every iteration. f_min : float, optional @@ -216,6 +218,7 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa restart = kwargs.pop('restart', None) restart_id = kwargs.pop('restart_id', -1) + lazy_loading = kwargs.pop('lazy_loading', False) dump = kwargs.pop('dump', True) safe = kwargs.pop('safe', True) @@ -290,6 +293,9 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa shot_ids = problem.acquisitions.select_shot_ids(**select_shots) num_shots = len(shot_ids) + if lazy_loading: + problem.acquisitions.load(shot_ids=shot_ids, lazy_loading=False) + @runtime.async_for(shot_ids, safe=safe) async def loop(worker, shot_id): _kwargs = kwargs.copy() @@ -482,4 +488,7 @@ async def loop(worker, shot_id): optimisation_loop.num_blocks, iteration.total_loss, prev_loss)) logger.perf('====================================================================') + if lazy_loading: + problem.acquisitions.deallocate(shot_ids=shot_ids) + iteration.clear_run() diff --git a/stride/optimisation/optimisation_loop.py b/stride/optimisation/optimisation_loop.py index da084a24..24ab71b4 100644 --- a/stride/optimisation/optimisation_loop.py +++ b/stride/optimisation/optimisation_loop.py @@ -65,7 +65,7 @@ def __get_desc__(self, **kwargs): return description - def __set_desc__(self, description): + def __set_desc__(self, description, **kwargs): self.id = description.id self.submitted_shots = description.submitted_shots self.completed_shots = description.completed_shots @@ -252,7 +252,7 @@ def __get_desc__(self, **kwargs): return description - def __set_desc__(self, description): + def __set_desc__(self, description, **kwargs): self.id = description.id self.abs_id = description.abs_id @@ -260,7 +260,7 @@ def __set_desc__(self, description): for run_desc in description.runs: self._curr_run_idx += 1 run = IterationRun(self._curr_run_idx, self) - run.__set_desc__(run_desc) + run.__set_desc__(run_desc, **kwargs) self._runs[self._curr_run_idx] = run _serialisation_attrs = ['id', 'abs_id'] @@ -453,14 +453,14 @@ def __get_desc__(self, **kwargs): return description - def __set_desc__(self, description): + def __set_desc__(self, description, **kwargs): self.id = description.id self._num_iterations = description.num_iterations for iter_desc in description.iterations: iteration = Iteration(iter_desc.id, iter_desc.abs_id, self, self._optimisation_loop) - iteration.__set_desc__(iter_desc) + iteration.__set_desc__(iter_desc, **kwargs) self._iterations[iteration.id] = iteration self._current_iteration = self._iterations[description.current_iteration.id] @@ -671,13 +671,13 @@ def __get_desc__(self, **kwargs): return description - def __set_desc__(self, description): + def __set_desc__(self, description, **kwargs): self.running_id = description.running_id self._num_blocks = description.num_blocks for block_desc in description.blocks: block = Block(block_desc.id, self) - block.__set_desc__(block_desc) + block.__set_desc__(block_desc, **kwargs) self._blocks[block.id] = block self._current_block = self._blocks[description.current_block.id] diff --git a/stride/problem/acquisitions.py b/stride/problem/acquisitions.py index 13a0180e..00df8505 100644 --- a/stride/problem/acquisitions.py +++ b/stride/problem/acquisitions.py @@ -316,6 +316,23 @@ def sub_problem(self, shot, sub_problem): return shot + def deallocate(self): + """ + Deallocate memory associated with shot. + + Returns + ------- + + """ + if self.wavelets is not None: + self.wavelets.deallocate() + + if self.observed is not None: + self.observed.deallocate() + + if self.delays is not None: + self.delays.deallocate() + def plot(self, **kwargs): """ Plot wavelets and observed for this shot if they are allocated. @@ -423,7 +440,7 @@ def __get_desc__(self, **kwargs): return description - def __set_desc__(self, description): + def __set_desc__(self, description, **kwargs): self.id = description.id for source_id in description.source_ids: @@ -440,17 +457,19 @@ def __set_desc__(self, description): receiver = None self._receivers[receiver_id] = receiver + lazy_loading = kwargs.pop('lazy_loading', False) + self.wavelets = Traces(name='wavelets', transducer_ids=self.source_ids, grid=self.grid) - if 'wavelets' in description: - self.wavelets.__set_desc__(description.wavelets) + if 'wavelets' in description and not lazy_loading: + self.wavelets.__set_desc__(description.wavelets, **kwargs) self.observed = Traces(name='observed', transducer_ids=self.receiver_ids, grid=self.grid) - if 'observed' in description: - self.observed.__set_desc__(description.observed) + if 'observed' in description and not lazy_loading: + self.observed.__set_desc__(description.observed, **kwargs) self.delays = Traces(name='delays', transducer_ids=self.source_ids, shape=(len(self.source_ids), 1), grid=self.grid) - if 'delays' in description: - self.delays.__set_desc__(description.delays) + if 'delays' in description and not lazy_loading: + self.delays.__set_desc__(description.delays, **kwargs) class Sequence(ProblemBase): @@ -657,7 +676,7 @@ def __get_desc__(self, **kwargs): return description - def __set_desc__(self, description): + def __set_desc__(self, description, **kwargs): self.id = description.id self.acq = description.acq @@ -707,6 +726,7 @@ def __init__(self, name='acquisitions', problem=None, **kwargs): self._sequences = OrderedDict() self._shot_selection = [] self._sequence_selection = [] + self._prev_load = None, None @property def shots(self): @@ -1026,6 +1046,25 @@ def default(self): sources=[source], receivers=receivers, geometry=self._geometry, problem=self.problem)) + def deallocate(self, shot_ids=None): + """ + Deallocate memory associated with shots. + + Parameters + ---------- + shot_ids : list, optional + Set of shot IDs to deallocate. + + Returns + ------- + + """ + if shot_ids is None: + shot_ids = self.shot_ids + + for shot_id in shot_ids: + self._shots[shot_id].deallocate() + def plot(self, **kwargs): """ Plot wavelets and observed for for all shots if they are allocated. @@ -1172,6 +1211,35 @@ def sub_problem(self, shot, sub_problem): return sub_acquisitions + def load(self, *args, **kwargs): + """ + Load the object using ``__set_desc__`` to digest the description. + + See :class:`~mosaic.file_manipulation.h5.HDF5` for more information on the parameters of this method. + + Parameters + ---------- + shot_ids : list, optional + List of shot IDs to load. + + Returns + ------- + + """ + shot_ids = kwargs.pop('shot_ids', None) + + prev_args, prev_kwargs = self._prev_load + if prev_args is not None: + args = args + prev_args[len(args):] if len(args) < len(prev_args) else args + if prev_kwargs is not None: + kwargs_ = prev_kwargs.copy() + kwargs_.update(kwargs) + kwargs = kwargs_ + + super().load(*args, filter={'shots': shot_ids} if shot_ids is not None else None, **kwargs) + + self._prev_load = args, kwargs + def __get_desc__(self, **kwargs): legacy = kwargs.pop('legacy', False) shot_ids = kwargs.pop('shot_ids', None) @@ -1210,7 +1278,7 @@ def __get_desc__(self, **kwargs): return description - def __set_desc__(self, description): + def __set_desc__(self, description, **kwargs): if 'shots' in description: shots = description.shots else: @@ -1226,7 +1294,7 @@ def __set_desc__(self, description): self.add(shot) shot = self.get(shot_desc.id) - shot.__set_desc__(shot_desc) + shot.__set_desc__(shot_desc, **kwargs) if 'sequences' in description: sequences = description.sequences @@ -1241,4 +1309,4 @@ def __set_desc__(self, description): self.add_sequence(sequence) sequence = self.get_sequence(seq_desc.id) - sequence.__set_desc__(seq_desc) + sequence.__set_desc__(seq_desc, **kwargs) diff --git a/stride/problem/base.py b/stride/problem/base.py index c6ef3199..8e26100c 100644 --- a/stride/problem/base.py +++ b/stride/problem/base.py @@ -148,7 +148,7 @@ def load(self, *args, **kwargs): with h5.HDF5(*args, **kwargs, mode='r') as file: description = file.load(filter=kwargs.pop('filter', None), only=kwargs.pop('only', None)) - self.__set_desc__(description) + self.__set_desc__(description, **kwargs) def rm(self, *args, **kwargs): """ @@ -166,7 +166,7 @@ def rm(self, *args, **kwargs): def __get_desc__(self, **kwargs): return {} - def __set_desc__(self, description): + def __set_desc__(self, description, **kwargs): pass @@ -270,7 +270,7 @@ def load(self, *args, **kwargs): self._grid.slow_time = slow_time - self.__set_desc__(description) + self.__set_desc__(description, **kwargs) def grid_description(self): """ diff --git a/stride/problem/data.py b/stride/problem/data.py index 65fe13cc..3cb954d6 100644 --- a/stride/problem/data.py +++ b/stride/problem/data.py @@ -711,7 +711,7 @@ def __get_desc__(self, **kwargs): return description - def __set_desc__(self, description): + def __set_desc__(self, description, **kwargs): self._shape = description.shape self._extended_shape = description.extended_shape self._dtype = np.dtype(description.dtype) @@ -1183,7 +1183,7 @@ def __get_desc__(self, **kwargs): return description - def __set_desc__(self, description): + def __set_desc__(self, description, **kwargs): self._shape = description.shape self._dtype = np.dtype(description.dtype) self._time_dependent = description.time_dependent @@ -1375,7 +1375,7 @@ def __get_desc__(self, **kwargs): return description - def __set_desc__(self, description): + def __set_desc__(self, description, **kwargs): self._shape = description.shape self._dtype = np.dtype(description.dtype) self._time_dependent = description.time_dependent @@ -1631,8 +1631,8 @@ def __get_desc__(self, **kwargs): return description - def __set_desc__(self, description): - super().__set_desc__(description) + def __set_desc__(self, description, **kwargs): + super().__set_desc__(description, **kwargs) self._transducer_ids = description.transducer_ids @@ -1814,7 +1814,7 @@ def __get_desc__(self, **kwargs): return description - def __set_desc__(self, description): + def __set_desc__(self, description, **kwargs): self._shape = description.shape self._dtype = np.dtype(description.dtype) self._num = description.num diff --git a/stride/problem/geometry.py b/stride/problem/geometry.py index 4bf0aca3..7975c5af 100644 --- a/stride/problem/geometry.py +++ b/stride/problem/geometry.py @@ -89,7 +89,7 @@ def __get_desc__(self, **kwargs): return description - def __set_desc__(self, description, transducers=None): + def __set_desc__(self, description, transducers=None, **kwargs): self.id = description.id self.transducer = transducers.get(description.transducer_id) if hasattr(description.coordinates, 'load'): @@ -447,7 +447,7 @@ def __get_desc__(self, **kwargs): return description - def __set_desc__(self, description): + def __set_desc__(self, description, **kwargs): locations = description.locations if isinstance(locations, mosaic.types.Struct): locations = locations.values() @@ -468,4 +468,4 @@ def __set_desc__(self, description): self.add_location(instance) instance = self.get(location_desc.id) - instance.__set_desc__(location_desc, self._transducers) + instance.__set_desc__(location_desc, self._transducers, **kwargs) diff --git a/stride/problem/transducer_types/transducer.py b/stride/problem/transducer_types/transducer.py index e7ca6ea0..e7f0b0a4 100644 --- a/stride/problem/transducer_types/transducer.py +++ b/stride/problem/transducer_types/transducer.py @@ -82,5 +82,5 @@ def __get_desc__(self, **kwargs): return description - def __set_desc__(self, description): + def __set_desc__(self, description, **kwargs): self.id = description.id diff --git a/stride/problem/transducers.py b/stride/problem/transducers.py index 90ba1bc9..d52977db 100644 --- a/stride/problem/transducers.py +++ b/stride/problem/transducers.py @@ -206,11 +206,11 @@ def __get_desc__(self, **kwargs): return description - def __set_desc__(self, description): + def __set_desc__(self, description, **kwargs): for transducer_desc in description.transducers: transducer_type = getattr(transducer_types, camel_case(transducer_desc.type)) transducer = transducer_type(transducer_desc.id, grid=self.grid) - transducer.__set_desc__(transducer_desc) + transducer.__set_desc__(transducer_desc, **kwargs) self.add(transducer)