diff --git a/stride/physics/iso_acoustic/devito.py b/stride/physics/iso_acoustic/devito.py index cc53143c..89064ff6 100644 --- a/stride/physics/iso_acoustic/devito.py +++ b/stride/physics/iso_acoustic/devito.py @@ -336,7 +336,7 @@ async def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): else: # If the wavefield is lazily streamed, re-create every time - if save_wavefield and 'nvidia' in platform and devito.pro_available: + if 'nvidia' in platform and devito.pro_available: self.dev_grid.undersampled_time_function('p_saved', bounds=kwargs.pop('save_bounds', None), factor=self.undersampling_factor, @@ -365,6 +365,7 @@ async def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs): rho_with_halo = self.dev_grid.with_halo(rho.extended_data) self.dev_grid.vars.rho.data_with_halo[:] = rho_with_halo + self.dev_grid.vars.buoy.data_with_halo[:] = 1/rho_with_halo if alpha is not None: self.logger.perf('(ShotID %d) Using attenuation with power %d' % (problem.shot_id, self.attenuation_power)) @@ -526,6 +527,7 @@ def _rm_tmpdir(): self.dev_grid.deallocate('rec') self.dev_grid.deallocate('vp') self.dev_grid.deallocate('rho') + self.dev_grid.deallocate('buoy') self.dev_grid.deallocate('alpha', collect=True) return traces @@ -640,6 +642,7 @@ async def before_adjoint(self, adjoint_source, wavelets, vp, rho=None, alpha=Non if rho is not None: rho_with_halo = self.dev_grid.with_halo(rho.extended_data) self.dev_grid.vars.rho.data_with_halo[:] = rho_with_halo + self.dev_grid.vars.buoy.data_with_halo[:] = 1/rho_with_halo if alpha is not None: db_to_neper = 100 * (1e-6 / (2*np.pi))**self.attenuation_power / (20 * np.log10(np.exp(1))) @@ -732,6 +735,7 @@ async def after_adjoint(self, adjoint_source, wavelets, vp, rho=None, alpha=None self.dev_grid.deallocate('rec') self.dev_grid.deallocate('vp') self.dev_grid.deallocate('rho') + self.dev_grid.deallocate('buoy') self.dev_grid.deallocate('alpha', collect=True) return await self.get_grad(wavelets, vp, rho, alpha, **kwargs) @@ -1219,7 +1223,7 @@ def _medium_functions(self, vp, rho=None, alpha=None, **kwargs): if rho is not None: rho_fun = self.dev_grid.function('rho', **_kwargs) - buoy_fun = 1/rho_fun + buoy_fun = self.dev_grid.function('buoy', **_kwargs) else: rho_fun = buoy_fun = None