Skip to content

Commit

Permalink
Merge pull request #57 from charlesincharge/fix/devito-platform-none
Browse files Browse the repository at this point in the history
Handle `platform=None` in `async def before_forward`
  • Loading branch information
ccuetom authored Oct 10, 2023
2 parents 96ea0da + ee7a334 commit fb47928
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions stride/physics/iso_acoustic/devito.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ async def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs):
diff_source = kwargs.pop('diff_source', False)
save_compression = kwargs.get('save_compression',
'bitcomp' if self.space.dim > 2 else None)
save_compression = save_compression if 'nvidia' in platform and devito.pro_available else None
save_compression = save_compression if platform and 'nvidia' in platform and devito.pro_available else None

# If there's no previous operator, generate one
if self.state_operator.devito_operator is None:
Expand Down Expand Up @@ -304,7 +304,7 @@ async def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs):

# Define the saving of the wavefield
if save_wavefield is True:
layers = devito.HostDevice if 'nvidia' in platform else devito.NoLayers
layers = devito.HostDevice if platform and 'nvidia' in platform else devito.NoLayers
p_saved = self.dev_grid.undersampled_time_function('p_saved',
bounds=kwargs.pop('save_bounds', None),
factor=self.undersampling_factor,
Expand Down Expand Up @@ -336,7 +336,7 @@ async def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs):

else:
# If the wavefield is lazily streamed, re-create every time
if 'nvidia' in platform and devito.pro_available:
if platform and 'nvidia' in platform and devito.pro_available:
self.dev_grid.undersampled_time_function('p_saved',
bounds=kwargs.pop('save_bounds', None),
factor=self.undersampling_factor,
Expand Down

0 comments on commit fb47928

Please sign in to comment.