Skip to content

Commit

Permalink
Merge pull request #56 from trustimaging/fix-buoy
Browse files Browse the repository at this point in the history
Restore buoy instead of 1/rho
  • Loading branch information
ccuetom authored Oct 5, 2023
2 parents 25c880c + b7a6285 commit 96ea0da
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions stride/physics/iso_acoustic/devito.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ async def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs):

else:
# If the wavefield is lazily streamed, re-create every time
if save_wavefield and 'nvidia' in platform and devito.pro_available:
if 'nvidia' in platform and devito.pro_available:
self.dev_grid.undersampled_time_function('p_saved',
bounds=kwargs.pop('save_bounds', None),
factor=self.undersampling_factor,
Expand Down Expand Up @@ -365,6 +365,7 @@ async def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs):

rho_with_halo = self.dev_grid.with_halo(rho.extended_data)
self.dev_grid.vars.rho.data_with_halo[:] = rho_with_halo
self.dev_grid.vars.buoy.data_with_halo[:] = 1/rho_with_halo

if alpha is not None:
self.logger.perf('(ShotID %d) Using attenuation with power %d' % (problem.shot_id, self.attenuation_power))
Expand Down Expand Up @@ -526,6 +527,7 @@ def _rm_tmpdir():
self.dev_grid.deallocate('rec')
self.dev_grid.deallocate('vp')
self.dev_grid.deallocate('rho')
self.dev_grid.deallocate('buoy')
self.dev_grid.deallocate('alpha', collect=True)

return traces
Expand Down Expand Up @@ -640,6 +642,7 @@ async def before_adjoint(self, adjoint_source, wavelets, vp, rho=None, alpha=Non
if rho is not None:
rho_with_halo = self.dev_grid.with_halo(rho.extended_data)
self.dev_grid.vars.rho.data_with_halo[:] = rho_with_halo
self.dev_grid.vars.buoy.data_with_halo[:] = 1/rho_with_halo

if alpha is not None:
db_to_neper = 100 * (1e-6 / (2*np.pi))**self.attenuation_power / (20 * np.log10(np.exp(1)))
Expand Down Expand Up @@ -732,6 +735,7 @@ async def after_adjoint(self, adjoint_source, wavelets, vp, rho=None, alpha=None
self.dev_grid.deallocate('rec')
self.dev_grid.deallocate('vp')
self.dev_grid.deallocate('rho')
self.dev_grid.deallocate('buoy')
self.dev_grid.deallocate('alpha', collect=True)

return await self.get_grad(wavelets, vp, rho, alpha, **kwargs)
Expand Down Expand Up @@ -1219,7 +1223,7 @@ def _medium_functions(self, vp, rho=None, alpha=None, **kwargs):

if rho is not None:
rho_fun = self.dev_grid.function('rho', **_kwargs)
buoy_fun = 1/rho_fun
buoy_fun = self.dev_grid.function('buoy', **_kwargs)
else:
rho_fun = buoy_fun = None

Expand Down

0 comments on commit 96ea0da

Please sign in to comment.