Skip to content

Commit

Permalink
Fixed marmottant inconsistencies
Browse files Browse the repository at this point in the history
  • Loading branch information
ccuetom committed Aug 1, 2024
1 parent bb20a84 commit 8e410f3
Show file tree
Hide file tree
Showing 6 changed files with 529 additions and 8,066 deletions.
6 changes: 3 additions & 3 deletions stride/physics/common/devito.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def __init__(self, space_order, time_order, time_dim=None, **kwargs):

if space is None:
origin = (0,)
extended_shape = (1,)
extended_shape = (2,)
extended_extent = (1,)
else:
extra = space.absorbing
Expand Down Expand Up @@ -286,7 +286,7 @@ def __init__(self, space_order, time_order, time_dim=None, **kwargs):
if parent_grid is not None:
dimensions = parent_grid.dimensions
time_dimension = devito.TimeDimension(name='time_inner',
spacing=devito.types.Scalar(name='dt_inner', is_const=True))
spacing=devito.Scalar(name='dt_inner', is_const=True))
self.num_inner = kwargs.pop('num_inner', 1)
else:
self.num_inner = None
Expand Down Expand Up @@ -691,7 +691,7 @@ def sparse_function(self, name, num=1, space_order=None,

reference_gridpoints, coefficients = self._calculate_hicks(coordinates)

fun = devito.PrecomputedSparseFunction(r=r,
fun = devito.PrecomputedSparseFunction(r=r+1,
gridpoints=reference_gridpoints,
interpolation_coeffs=coefficients,
**sparse_kwargs)
Expand Down
3 changes: 2 additions & 1 deletion stride/physics/common/import_devito.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@

from devito import * # noqa: F401
from devito.types import Symbol # noqa: F401
from devito.types import Symbol, Scalar # noqa: F401
from devito.symbolics import INT, IntDiv # noqa: F401
from devito import TimeFunction as TimeFunctionOSS # noqa: F401

try:
from devitopro import * # noqa: F401
Expand Down
7 changes: 4 additions & 3 deletions stride/physics/iso_acoustic/devito.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def before_forward(self, wavelets, vp, rho=None, alpha=None, **kwargs):
# ValueError: Cannot access `shape_allocated` as unfinalized - so no size estimate
pass

if self._needs_grad(wavelets, rho, alpha):
if self._needs_grad(wavelets, rho, alpha, **kwargs):
p_saved_expr = p
else:
p_saved_expr = self._forward_save(p)
Expand Down Expand Up @@ -1595,8 +1595,9 @@ def _weights(self):
def _dt_max(self, k, h, vp_max):
return k * h / vp_max * 1 / np.sqrt(self.space.dim)

def _needs_grad(self, *wrt):
return any(v is not None and v.needs_grad for v in wrt)
def _needs_grad(self, *wrt, **kwargs):
force_raw_wavefield = kwargs.pop('force_raw_wavefield', False)
return any(v is not None and v.needs_grad for v in wrt) or force_raw_wavefield

def _forward_save(self, field):
return field.dt2
Expand Down
10 changes: 7 additions & 3 deletions stride/physics/marmottant/devito.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,10 @@ def run_forward(self, *args, **kwargs):
"""
functions = dict()

time_bounds = kwargs.get('time_bounds', (0, self.time.extended_num))
self.state_operator.run(dt=self.time.step,
time_m=1,
time_M=time_bounds[1]-1,
**functions,
**kwargs.pop('devito_args', {}))

Expand Down Expand Up @@ -286,7 +289,7 @@ def sub_stencil(self, **kwargs):
rho = parent_grid.vars.rho
except AttributeError:
rho = self.dev_grid.vars.rho_sparse
vp2 = parent_grid.vars.vp2
vp2 = parent_grid.vars.vp**2

inject_term = v_inject.inject(field=p_out.forward, expr=vp2 * self.time.step**2 * rho * inject_scale * v_inject)

Expand Down Expand Up @@ -481,6 +484,7 @@ def _make_saved_time_function(self, name, num, **kwargs):
shape=(self.time.num, num),
space_order=self.space_order,
time_order=self.time_order,
layers=devito.NoLayers,
**kwargs)

def _make_interp_function(self, name, value, x_0, num, **kwargs):
Expand Down Expand Up @@ -519,7 +523,7 @@ def _make_interp_function(self, name, value, x_0, num, **kwargs):
return fun, dense_fun, interp_term

def _make_interp_time_function(self, name, value, x_0, num, **kwargs):
if not isinstance(value, devito.TimeFunction):
if not isinstance(value, devito.TimeFunctionOSS):
fun = self._make_saved_time_function(name, num=num, save=self.time.num)
fun.data[:] = value.data.T

Expand All @@ -541,7 +545,7 @@ def _make_interp_time_function(self, name, value, x_0, num, **kwargs):

dense_fun = value

interp_term = fun.interpolate(expr=dense_fun.forward)
interp_term = fun.interpolate(expr=dense_fun)

if x_0 is None:
raise ValueError('Bubble location x_0 needs to be provided when'
Expand Down
Loading

0 comments on commit 8e410f3

Please sign in to comment.