From 2415d7b8a116523c06e2e7405b409f34fe417807 Mon Sep 17 00:00:00 2001 From: Carlos Cueto Date: Wed, 8 May 2024 12:49:18 +0100 Subject: [PATCH] Use separate operators for saved/non-saved --- stride/__init__.py | 2 +- stride/physics/iso_acoustic/devito.py | 45 +++++++++++++++++++++------ 2 files changed, 36 insertions(+), 11 deletions(-) diff --git a/stride/__init__.py b/stride/__init__.py index 4275fca..f8e4178 100644 --- a/stride/__init__.py +++ b/stride/__init__.py @@ -446,7 +446,7 @@ async def loop(worker, shot_id): runtime=worker, **_kwargs).result() # clear up - await pde.deallocate_wavefield(deallocate=True, runtime=worker, **_kwargs) + # await pde.deallocate_wavefield(deallocate=True, runtime=worker, **_kwargs) fun.clear_graph() iteration.add_loss(fun) diff --git a/stride/physics/iso_acoustic/devito.py b/stride/physics/iso_acoustic/devito.py index db65bec..0136f1d 100644 --- a/stride/physics/iso_acoustic/devito.py +++ b/stride/physics/iso_acoustic/devito.py @@ -156,6 +156,10 @@ def __init__(self, **kwargs): name='acoustic_iso_state', grid=self.dev_grid, **kwargs) + self.state_operator_save = OperatorDevito(self.space_order, self.time_order, + name='acoustic_iso_state_save', + grid=self.dev_grid, + **kwargs) self.adjoint_operator = OperatorDevito(self.space_order, self.time_order, name='acoustic_iso_adjoint', grid=self.dev_grid, @@ -164,11 +168,13 @@ def __init__(self, **kwargs): if self._cached_operator: warehouse['%s_dev_grid' % cached_name] = self.dev_grid warehouse['%s_state_operator' % cached_name] = self.state_operator + warehouse['%s_state_operator_save' % cached_name] = self.state_operator_save warehouse['%s_adjoint_operator' % cached_name] = self.adjoint_operator else: self.dev_grid = warehouse['%s_dev_grid' % cached_name] self.state_operator = warehouse['%s_state_operator' % cached_name] + self.state_operator_save = warehouse['%s_state_operator_save' % cached_name] self.adjoint_operator = warehouse['%s_adjoint_operator' % cached_name] self.boundary = None @@ -182,6 +188,7 @@ def __init__(self, **kwargs): def clear_operators(self): self.state_operator.devito_operator = None + self.state_operator_save.devito_operator = None self.adjoint_operator.devito_operator = None def deallocate_wavefield(self, platform='cpu', deallocate=False, **kwargs): @@ -407,9 +414,13 @@ async def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): if self.attenuation_power == 2: kwargs['devito_config']['opt'] = 'noop' - self.state_operator.set_operator(stencil + src_term + rec_term + update_saved, + self.state_operator.set_operator(stencil + src_term + rec_term, **kwargs) self.state_operator.compile() + if save_wavefield is True: + self.state_operator_save.set_operator(stencil + src_term + rec_term + update_saved, + **kwargs) + self.state_operator_save.compile() else: # If the source/receiver size has changed, then create new functions for them @@ -492,6 +503,19 @@ async def run_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): ------- """ + problem = kwargs.get('problem') + shot = problem.shot + + dump_forward_wavefield = kwargs.pop('dump_forward_wavefield', False) + dump_wavefield_id = kwargs.pop('dump_wavefield_id', shot.id) + save_wavefield = kwargs.pop('save_wavefield', bool(dump_forward_wavefield) and dump_wavefield_id == shot.id) + if save_wavefield is False: + save_wavefield = vp.needs_grad + if rho is not None: + save_wavefield |= rho.needs_grad + if alpha is not None: + save_wavefield |= alpha.needs_grad + functions = dict( vp=self.dev_grid.vars.vp, src=self.dev_grid.vars.src, @@ -500,7 +524,7 @@ async def run_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): devito_args = kwargs.get('devito_args', {}) - if 'p_saved' in self.dev_grid.vars: + if 'p_saved' in self.dev_grid.vars and save_wavefield: if self._wavefield is None: self._wavefield = self.dev_grid.func('p_saved') @@ -510,17 +534,21 @@ async def run_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): devito_args['nbits'] = kwargs.get('nbits_compression', devito_args.get('nbits', 9)) + op = self.state_operator_save + else: + op = self.state_operator + if np.linalg.norm(wavelets.data) < 1e-31: problem = kwargs.pop('problem') self.logger.warn('(ShotID %d) Empty wavelets, not running forward' % problem.shot_id) return time_bounds = kwargs.get('time_bounds', (0, self.time.extended_num)) - self.state_operator.run(dt=self.time.step, - time_m=1, - time_M=time_bounds[1]-1, - **functions, - **devito_args) + op.run(dt=self.time.step, + time_m=1, + time_M=time_bounds[1]-1, + **functions, + **devito_args) async def after_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): """ @@ -609,9 +637,6 @@ def _rm_tmpdir(): self.dev_grid.deallocate('p_saved') - else: - self._wavefield = None - traces_data = np.asarray(self.dev_grid.vars.rec.data, dtype=np.float32).T traces = shot.observed.alike(name='modelled', data=traces_data, shape=None, extended_shape=None, inner=None)