Skip to content

Commit

Permalink
api: process injected expression dimensions in case it's not the spar…
Browse files Browse the repository at this point in the history
…se function
  • Loading branch information
mloubout committed Sep 15, 2023
1 parent d715f3f commit 7df78f7
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
18 changes: 11 additions & 7 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,12 @@ def _rdim(self):

return DimensionTuple(*rdims, getters=self._gdims)

def _augment_implicit_dims(self, implicit_dims):
def _augment_implicit_dims(self, implicit_dims, extras=None):
extra = tuple([i for v in extras for i in v.dimensions])
if self.sfunction._sparse_position == -1:
return self.sfunction.dimensions + as_tuple(implicit_dims)
return self.sfunction.dimensions + as_tuple(implicit_dims) + extra
else:
return as_tuple(implicit_dims) + self.sfunction.dimensions
return as_tuple(implicit_dims) + self.sfunction.dimensions + extra

def _coeff_temps(self, implicit_dims):
return []
Expand Down Expand Up @@ -252,8 +253,6 @@ def _interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None):
interpolation expression, but that should be honored when constructing
the operator.
"""
implicit_dims = self._augment_implicit_dims(implicit_dims)

# Derivatives must be evaluated before the introduction of indirect accesses
try:
_expr = expr.evaluate
Expand All @@ -263,6 +262,9 @@ def _interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None):

variables = list(retrieve_function_carriers(_expr))

# Implicit dimensions
implicit_dims = self._augment_implicit_dims(implicit_dims)

# List of indirection indices for all adjacent grid points
idx_subs, temps = self._interp_idx(variables, implicit_dims=implicit_dims)

Expand Down Expand Up @@ -295,8 +297,6 @@ def _inject(self, field, expr, implicit_dims=None):
injection expression, but that should be honored when constructing
the operator.
"""
implicit_dims = self._augment_implicit_dims(implicit_dims)

# Make iterable to support inject((u, v), expr=expr)
# or inject((u, v), expr=(expr1, expr2))
fields, exprs = as_tuple(field), as_tuple(expr)
Expand All @@ -315,6 +315,10 @@ def _inject(self, field, expr, implicit_dims=None):
_exprs = exprs

variables = list(v for e in _exprs for v in retrieve_function_carriers(e))

# Implicit dimensions
implicit_dims = self._augment_implicit_dims(implicit_dims, variables)

variables = variables + list(fields)

# List of indirection indices for all adjacent grid points
Expand Down
25 changes: 25 additions & 0 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from sympy import Float

from conftest import assert_structure
from devito import (Grid, Operator, Dimension, SparseFunction, SparseTimeFunction,
Function, TimeFunction, DefaultDimension, Eq,
PrecomputedSparseFunction, PrecomputedSparseTimeFunction,
Expand Down Expand Up @@ -734,3 +735,27 @@ class SparseFirst(SparseFunction):
op(time_M=10)
expected = 10*11/2 # n (n+1)/2
assert np.allclose(s.data, expected)


def test_inject_function():
nt = 11

grid = Grid(shape=(5, 5))
u = TimeFunction(name="u", grid=grid, time_order=2)
src = SparseTimeFunction(name="src", grid=grid, nt=nt, npoint=1,
coordinates=[[0.5, 0.5]])

nfreq = 5
freq_dim = DefaultDimension(name="freq", default_value=nfreq)
omega = Function(name="omega", dimensions=(freq_dim,), shape=(nfreq,), grid=grid)
omega.data.fill(1.)

inj = src.inject(field=u.forward, expr=omega)

op = Operator([inj])

assert_structure(op, ['p_src', 't', 't,p_src,freq', 't,p_src,freq,rsrcx,rsrcy'],
'p_src,t,p_src,freq,rsrcx,rsrcy')

op(time_M=0)
assert u.data[1, 2, 2] == nfreq

0 comments on commit 7df78f7

Please sign in to comment.