Skip to content

Commit

Permalink
Added lazy loading functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
ccuetom committed Oct 25, 2024
1 parent 03fd31a commit 781d182
Show file tree
Hide file tree
Showing 8 changed files with 110 additions and 33 deletions.
9 changes: 9 additions & 0 deletions stride/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa
select_shots : dict, optional
Rules for selecting available shots per iteration, defaults to taking all shots. For
details on this see :func:`~stride.problem.acquisitions.Acquisitions.select_shot_ids`.
lazy_loading : bool, optional
Whether to load shot data every iteration to save memory.
dump : bool, optional
Whether or not to save to disk the updated variable after every iteration.
f_min : float, optional
Expand Down Expand Up @@ -216,6 +218,7 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa
restart = kwargs.pop('restart', None)
restart_id = kwargs.pop('restart_id', -1)

lazy_loading = kwargs.pop('lazy_loading', False)
dump = kwargs.pop('dump', True)
safe = kwargs.pop('safe', True)

Expand Down Expand Up @@ -290,6 +293,9 @@ async def adjoint(problem, pde, loss, optimisation_loop, optimiser, *args, **kwa
shot_ids = problem.acquisitions.select_shot_ids(**select_shots)
num_shots = len(shot_ids)

if lazy_loading:
problem.acquisitions.load(shot_ids=shot_ids, lazy_loading=False)

@runtime.async_for(shot_ids, safe=safe)
async def loop(worker, shot_id):
_kwargs = kwargs.copy()
Expand Down Expand Up @@ -482,4 +488,7 @@ async def loop(worker, shot_id):
optimisation_loop.num_blocks, iteration.total_loss, prev_loss))
logger.perf('====================================================================')

if lazy_loading:
problem.acquisitions.deallocate(shot_ids=shot_ids)

iteration.clear_run()
14 changes: 7 additions & 7 deletions stride/optimisation/optimisation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __get_desc__(self, **kwargs):

return description

def __set_desc__(self, description):
def __set_desc__(self, description, **kwargs):
self.id = description.id
self.submitted_shots = description.submitted_shots
self.completed_shots = description.completed_shots
Expand Down Expand Up @@ -252,15 +252,15 @@ def __get_desc__(self, **kwargs):

return description

def __set_desc__(self, description):
def __set_desc__(self, description, **kwargs):
self.id = description.id
self.abs_id = description.abs_id

self._curr_run_idx = -1
for run_desc in description.runs:
self._curr_run_idx += 1
run = IterationRun(self._curr_run_idx, self)
run.__set_desc__(run_desc)
run.__set_desc__(run_desc, **kwargs)
self._runs[self._curr_run_idx] = run

_serialisation_attrs = ['id', 'abs_id']
Expand Down Expand Up @@ -453,14 +453,14 @@ def __get_desc__(self, **kwargs):

return description

def __set_desc__(self, description):
def __set_desc__(self, description, **kwargs):
self.id = description.id
self._num_iterations = description.num_iterations

for iter_desc in description.iterations:
iteration = Iteration(iter_desc.id, iter_desc.abs_id,
self, self._optimisation_loop)
iteration.__set_desc__(iter_desc)
iteration.__set_desc__(iter_desc, **kwargs)
self._iterations[iteration.id] = iteration

self._current_iteration = self._iterations[description.current_iteration.id]
Expand Down Expand Up @@ -671,13 +671,13 @@ def __get_desc__(self, **kwargs):

return description

def __set_desc__(self, description):
def __set_desc__(self, description, **kwargs):
self.running_id = description.running_id
self._num_blocks = description.num_blocks

for block_desc in description.blocks:
block = Block(block_desc.id, self)
block.__set_desc__(block_desc)
block.__set_desc__(block_desc, **kwargs)
self._blocks[block.id] = block

self._current_block = self._blocks[description.current_block.id]
90 changes: 79 additions & 11 deletions stride/problem/acquisitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,23 @@ def sub_problem(self, shot, sub_problem):

return shot

def deallocate(self):
"""
Deallocate memory associated with shot.
Returns
-------
"""
if self.wavelets is not None:
self.wavelets.deallocate()

if self.observed is not None:
self.observed.deallocate()

if self.delays is not None:
self.delays.deallocate()

def plot(self, **kwargs):
"""
Plot wavelets and observed for this shot if they are allocated.
Expand Down Expand Up @@ -423,7 +440,7 @@ def __get_desc__(self, **kwargs):

return description

def __set_desc__(self, description):
def __set_desc__(self, description, **kwargs):
self.id = description.id

for source_id in description.source_ids:
Expand All @@ -440,17 +457,19 @@ def __set_desc__(self, description):
receiver = None
self._receivers[receiver_id] = receiver

lazy_loading = kwargs.pop('lazy_loading', False)

self.wavelets = Traces(name='wavelets', transducer_ids=self.source_ids, grid=self.grid)
if 'wavelets' in description:
self.wavelets.__set_desc__(description.wavelets)
if 'wavelets' in description and not lazy_loading:
self.wavelets.__set_desc__(description.wavelets, **kwargs)

self.observed = Traces(name='observed', transducer_ids=self.receiver_ids, grid=self.grid)
if 'observed' in description:
self.observed.__set_desc__(description.observed)
if 'observed' in description and not lazy_loading:
self.observed.__set_desc__(description.observed, **kwargs)

self.delays = Traces(name='delays', transducer_ids=self.source_ids, shape=(len(self.source_ids), 1), grid=self.grid)
if 'delays' in description:
self.delays.__set_desc__(description.delays)
if 'delays' in description and not lazy_loading:
self.delays.__set_desc__(description.delays, **kwargs)


class Sequence(ProblemBase):
Expand Down Expand Up @@ -657,7 +676,7 @@ def __get_desc__(self, **kwargs):

return description

def __set_desc__(self, description):
def __set_desc__(self, description, **kwargs):
self.id = description.id
self.acq = description.acq

Expand Down Expand Up @@ -707,6 +726,7 @@ def __init__(self, name='acquisitions', problem=None, **kwargs):
self._sequences = OrderedDict()
self._shot_selection = []
self._sequence_selection = []
self._prev_load = None, None

@property
def shots(self):
Expand Down Expand Up @@ -1026,6 +1046,25 @@ def default(self):
sources=[source], receivers=receivers,
geometry=self._geometry, problem=self.problem))

def deallocate(self, shot_ids=None):
"""
Deallocate memory associated with shots.
Parameters
----------
shot_ids : list, optional
Set of shot IDs to deallocate.
Returns
-------
"""
if shot_ids is None:
shot_ids = self.shot_ids

for shot_id in shot_ids:
self._shots[shot_id].deallocate()

def plot(self, **kwargs):
"""
Plot wavelets and observed for for all shots if they are allocated.
Expand Down Expand Up @@ -1172,6 +1211,35 @@ def sub_problem(self, shot, sub_problem):

return sub_acquisitions

def load(self, *args, **kwargs):
"""
Load the object using ``__set_desc__`` to digest the description.
See :class:`~mosaic.file_manipulation.h5.HDF5` for more information on the parameters of this method.
Parameters
----------
shot_ids : list, optional
List of shot IDs to load.
Returns
-------
"""
shot_ids = kwargs.pop('shot_ids', None)

prev_args, prev_kwargs = self._prev_load
if prev_args is not None:
args = args + prev_args[len(args):] if len(args) < len(prev_args) else args
if prev_kwargs is not None:
kwargs_ = prev_kwargs.copy()
kwargs_.update(kwargs)
kwargs = kwargs_

super().load(*args, filter={'shots': shot_ids} if shot_ids is not None else None, **kwargs)

self._prev_load = args, kwargs

def __get_desc__(self, **kwargs):
legacy = kwargs.pop('legacy', False)
shot_ids = kwargs.pop('shot_ids', None)
Expand Down Expand Up @@ -1210,7 +1278,7 @@ def __get_desc__(self, **kwargs):

return description

def __set_desc__(self, description):
def __set_desc__(self, description, **kwargs):
if 'shots' in description:
shots = description.shots
else:
Expand All @@ -1226,7 +1294,7 @@ def __set_desc__(self, description):
self.add(shot)

shot = self.get(shot_desc.id)
shot.__set_desc__(shot_desc)
shot.__set_desc__(shot_desc, **kwargs)

if 'sequences' in description:
sequences = description.sequences
Expand All @@ -1241,4 +1309,4 @@ def __set_desc__(self, description):
self.add_sequence(sequence)

sequence = self.get_sequence(seq_desc.id)
sequence.__set_desc__(seq_desc)
sequence.__set_desc__(seq_desc, **kwargs)
6 changes: 3 additions & 3 deletions stride/problem/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def load(self, *args, **kwargs):
with h5.HDF5(*args, **kwargs, mode='r') as file:
description = file.load(filter=kwargs.pop('filter', None), only=kwargs.pop('only', None))

self.__set_desc__(description)
self.__set_desc__(description, **kwargs)

def rm(self, *args, **kwargs):
"""
Expand All @@ -166,7 +166,7 @@ def rm(self, *args, **kwargs):
def __get_desc__(self, **kwargs):
return {}

def __set_desc__(self, description):
def __set_desc__(self, description, **kwargs):
pass


Expand Down Expand Up @@ -270,7 +270,7 @@ def load(self, *args, **kwargs):

self._grid.slow_time = slow_time

self.__set_desc__(description)
self.__set_desc__(description, **kwargs)

def grid_description(self):
"""
Expand Down
12 changes: 6 additions & 6 deletions stride/problem/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ def __get_desc__(self, **kwargs):

return description

def __set_desc__(self, description):
def __set_desc__(self, description, **kwargs):
self._shape = description.shape
self._extended_shape = description.extended_shape
self._dtype = np.dtype(description.dtype)
Expand Down Expand Up @@ -1183,7 +1183,7 @@ def __get_desc__(self, **kwargs):

return description

def __set_desc__(self, description):
def __set_desc__(self, description, **kwargs):
self._shape = description.shape
self._dtype = np.dtype(description.dtype)
self._time_dependent = description.time_dependent
Expand Down Expand Up @@ -1375,7 +1375,7 @@ def __get_desc__(self, **kwargs):

return description

def __set_desc__(self, description):
def __set_desc__(self, description, **kwargs):
self._shape = description.shape
self._dtype = np.dtype(description.dtype)
self._time_dependent = description.time_dependent
Expand Down Expand Up @@ -1631,8 +1631,8 @@ def __get_desc__(self, **kwargs):

return description

def __set_desc__(self, description):
super().__set_desc__(description)
def __set_desc__(self, description, **kwargs):
super().__set_desc__(description, **kwargs)

self._transducer_ids = description.transducer_ids

Expand Down Expand Up @@ -1814,7 +1814,7 @@ def __get_desc__(self, **kwargs):

return description

def __set_desc__(self, description):
def __set_desc__(self, description, **kwargs):
self._shape = description.shape
self._dtype = np.dtype(description.dtype)
self._num = description.num
Expand Down
6 changes: 3 additions & 3 deletions stride/problem/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __get_desc__(self, **kwargs):

return description

def __set_desc__(self, description, transducers=None):
def __set_desc__(self, description, transducers=None, **kwargs):
self.id = description.id
self.transducer = transducers.get(description.transducer_id)
if hasattr(description.coordinates, 'load'):
Expand Down Expand Up @@ -447,7 +447,7 @@ def __get_desc__(self, **kwargs):

return description

def __set_desc__(self, description):
def __set_desc__(self, description, **kwargs):
locations = description.locations
if isinstance(locations, mosaic.types.Struct):
locations = locations.values()
Expand All @@ -468,4 +468,4 @@ def __set_desc__(self, description):
self.add_location(instance)

instance = self.get(location_desc.id)
instance.__set_desc__(location_desc, self._transducers)
instance.__set_desc__(location_desc, self._transducers, **kwargs)
2 changes: 1 addition & 1 deletion stride/problem/transducer_types/transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,5 @@ def __get_desc__(self, **kwargs):

return description

def __set_desc__(self, description):
def __set_desc__(self, description, **kwargs):
self.id = description.id
4 changes: 2 additions & 2 deletions stride/problem/transducers.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,11 @@ def __get_desc__(self, **kwargs):

return description

def __set_desc__(self, description):
def __set_desc__(self, description, **kwargs):
for transducer_desc in description.transducers:
transducer_type = getattr(transducer_types, camel_case(transducer_desc.type))
transducer = transducer_type(transducer_desc.id, grid=self.grid)

transducer.__set_desc__(transducer_desc)
transducer.__set_desc__(transducer_desc, **kwargs)

self.add(transducer)

0 comments on commit 781d182

Please sign in to comment.