Skip to content

Commit

Permalink
Use separate operators for saved/non-saved
Browse files Browse the repository at this point in the history
  • Loading branch information
ccuetom committed May 8, 2024
1 parent 57af14f commit 2415d7b
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 11 deletions.
2 changes: 1 addition & 1 deletion stride/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ async def loop(worker, shot_id):
runtime=worker, **_kwargs).result()

# clear up
await pde.deallocate_wavefield(deallocate=True, runtime=worker, **_kwargs)
# await pde.deallocate_wavefield(deallocate=True, runtime=worker, **_kwargs)
fun.clear_graph()

iteration.add_loss(fun)
Expand Down
45 changes: 35 additions & 10 deletions stride/physics/iso_acoustic/devito.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ def __init__(self, **kwargs):
name='acoustic_iso_state',
grid=self.dev_grid,
**kwargs)
self.state_operator_save = OperatorDevito(self.space_order, self.time_order,
name='acoustic_iso_state_save',
grid=self.dev_grid,
**kwargs)
self.adjoint_operator = OperatorDevito(self.space_order, self.time_order,
name='acoustic_iso_adjoint',
grid=self.dev_grid,
Expand All @@ -164,11 +168,13 @@ def __init__(self, **kwargs):
if self._cached_operator:
warehouse['%s_dev_grid' % cached_name] = self.dev_grid
warehouse['%s_state_operator' % cached_name] = self.state_operator
warehouse['%s_state_operator_save' % cached_name] = self.state_operator_save
warehouse['%s_adjoint_operator' % cached_name] = self.adjoint_operator

else:
self.dev_grid = warehouse['%s_dev_grid' % cached_name]
self.state_operator = warehouse['%s_state_operator' % cached_name]
self.state_operator_save = warehouse['%s_state_operator_save' % cached_name]
self.adjoint_operator = warehouse['%s_adjoint_operator' % cached_name]

self.boundary = None
Expand All @@ -182,6 +188,7 @@ def __init__(self, **kwargs):

def clear_operators(self):
self.state_operator.devito_operator = None
self.state_operator_save.devito_operator = None
self.adjoint_operator.devito_operator = None

def deallocate_wavefield(self, platform='cpu', deallocate=False, **kwargs):
Expand Down Expand Up @@ -407,9 +414,13 @@ async def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs):
if self.attenuation_power == 2:
kwargs['devito_config']['opt'] = 'noop'

self.state_operator.set_operator(stencil + src_term + rec_term + update_saved,
self.state_operator.set_operator(stencil + src_term + rec_term,
**kwargs)
self.state_operator.compile()
if save_wavefield is True:
self.state_operator_save.set_operator(stencil + src_term + rec_term + update_saved,
**kwargs)
self.state_operator_save.compile()

else:
# If the source/receiver size has changed, then create new functions for them
Expand Down Expand Up @@ -492,6 +503,19 @@ async def run_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs):
-------
"""
problem = kwargs.get('problem')
shot = problem.shot

dump_forward_wavefield = kwargs.pop('dump_forward_wavefield', False)
dump_wavefield_id = kwargs.pop('dump_wavefield_id', shot.id)
save_wavefield = kwargs.pop('save_wavefield', bool(dump_forward_wavefield) and dump_wavefield_id == shot.id)
if save_wavefield is False:
save_wavefield = vp.needs_grad
if rho is not None:
save_wavefield |= rho.needs_grad
if alpha is not None:
save_wavefield |= alpha.needs_grad

functions = dict(
vp=self.dev_grid.vars.vp,
src=self.dev_grid.vars.src,
Expand All @@ -500,7 +524,7 @@ async def run_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs):

devito_args = kwargs.get('devito_args', {})

if 'p_saved' in self.dev_grid.vars:
if 'p_saved' in self.dev_grid.vars and save_wavefield:
if self._wavefield is None:
self._wavefield = self.dev_grid.func('p_saved')

Expand All @@ -510,17 +534,21 @@ async def run_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs):
devito_args['nbits'] = kwargs.get('nbits_compression',
devito_args.get('nbits', 9))

op = self.state_operator_save
else:
op = self.state_operator

if np.linalg.norm(wavelets.data) < 1e-31:
problem = kwargs.pop('problem')
self.logger.warn('(ShotID %d) Empty wavelets, not running forward' % problem.shot_id)
return

time_bounds = kwargs.get('time_bounds', (0, self.time.extended_num))
self.state_operator.run(dt=self.time.step,
time_m=1,
time_M=time_bounds[1]-1,
**functions,
**devito_args)
op.run(dt=self.time.step,
time_m=1,
time_M=time_bounds[1]-1,
**functions,
**devito_args)

async def after_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs):
"""
Expand Down Expand Up @@ -609,9 +637,6 @@ def _rm_tmpdir():

self.dev_grid.deallocate('p_saved')

else:
self._wavefield = None

traces_data = np.asarray(self.dev_grid.vars.rec.data, dtype=np.float32).T
traces = shot.observed.alike(name='modelled', data=traces_data, shape=None, extended_shape=None, inner=None)

Expand Down

0 comments on commit 2415d7b

Please sign in to comment.