Skip to content

Commit

Permalink
api: process injected expression dimensions in case it's not the spar…
Browse files Browse the repository at this point in the history
…se function
  • Loading branch information
mloubout committed Sep 15, 2023
1 parent 76a3c4b commit 6212a00
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 18 deletions.
7 changes: 6 additions & 1 deletion devito/builtins/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,12 @@ def assign(f, rhs=0, options=None, name='assign', assign_halo=False, **kwargs):
symbolic_max=d.symbolic_max + h.right)
eqs = [eq.xreplace(subs) for eq in eqs]

dv.Operator(eqs, name=name, **kwargs)()
op = dv.Operator(eqs, name=name, **kwargs)
try:
op()
except ValueError:
# Corner case such as assign(u, v) with v a Buffered TimeFunction
op(time_M=f._time_size)


def smooth(f, g, axis=None):
Expand Down
23 changes: 16 additions & 7 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,17 @@ def _rdim(self):

return DimensionTuple(*rdims, getters=self._gdims)

def _augment_implicit_dims(self, implicit_dims):
def _augment_implicit_dims(self, implicit_dims, extras=None):
if extras is not None:
extra = set([i for v in extras for i in v.dimensions]) - set(self._gdims)
extra = tuple(extra)
else:
extra = tuple()

if self.sfunction._sparse_position == -1:
return self.sfunction.dimensions + as_tuple(implicit_dims)
return self.sfunction.dimensions + as_tuple(implicit_dims) + extra
else:
return as_tuple(implicit_dims) + self.sfunction.dimensions
return as_tuple(implicit_dims) + self.sfunction.dimensions + extra

def _coeff_temps(self, implicit_dims):
return []
Expand Down Expand Up @@ -252,8 +258,6 @@ def _interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None):
interpolation expression, but that should be honored when constructing
the operator.
"""
implicit_dims = self._augment_implicit_dims(implicit_dims)

# Derivatives must be evaluated before the introduction of indirect accesses
try:
_expr = expr.evaluate
Expand All @@ -263,6 +267,9 @@ def _interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None):

variables = list(retrieve_function_carriers(_expr))

# Implicit dimensions
implicit_dims = self._augment_implicit_dims(implicit_dims)

# List of indirection indices for all adjacent grid points
idx_subs, temps = self._interp_idx(variables, implicit_dims=implicit_dims)

Expand Down Expand Up @@ -295,8 +302,6 @@ def _inject(self, field, expr, implicit_dims=None):
injection expression, but that should be honored when constructing
the operator.
"""
implicit_dims = self._augment_implicit_dims(implicit_dims)

# Make iterable to support inject((u, v), expr=expr)
# or inject((u, v), expr=(expr1, expr2))
fields, exprs = as_tuple(field), as_tuple(expr)
Expand All @@ -315,6 +320,10 @@ def _inject(self, field, expr, implicit_dims=None):
_exprs = exprs

variables = list(v for e in _exprs for v in retrieve_function_carriers(e))

# Implicit dimensions
implicit_dims = self._augment_implicit_dims(implicit_dims, variables)

variables = variables + list(fields)

# List of indirection indices for all adjacent grid points
Expand Down
7 changes: 7 additions & 0 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,13 +566,17 @@ def _prepare_arguments(self, autotune=None, **kwargs):
"`%s=%s`, while `%s=%s` is expected. Perhaps you "
"forgot to override `%s`?" %
(p, k, v, k, args[k], p))

args = kwargs['args'] = args.reduce_all()

# DiscreteFunctions may be created from CartesianDiscretizations, which in
# turn could be Grids or SubDomains. Both may provide arguments
discretizations = {getattr(kwargs[p.name], 'grid', None) for p in overrides}
discretizations.update({getattr(p, 'grid', None) for p in defaults})
discretizations.discard(None)
# Remove subgrids
discretizations = {g for g in discretizations
if not any(d.is_Derived for d in g.dimensions)}
for i in discretizations:
args.update(i._arg_values(**kwargs))

Expand All @@ -585,6 +589,9 @@ def _prepare_arguments(self, autotune=None, **kwargs):
if configuration['mpi']:
raise ValueError("Multiple Grids found")
try:
# Take biggest grid, i.e discard grids with subdimensions
grids = {g for g in grids if not any(d.is_Sub for d in g.dimensions)}
# First grid as there is no heuristic on how to choose from the leftovers
grid = grids.pop()
except KeyError:
grid = None
Expand Down
3 changes: 3 additions & 0 deletions devito/tools/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ def compare_to_first(v):
return candidates[0]
elif all(map(compare_to_first, candidates)):
# return first non-range
for c in candidates:
if not isinstance(c, range):
return c
return candidates[0]
else:
raise ValueError("Unable to find unique value for key %s, candidates: %s"
Expand Down
15 changes: 6 additions & 9 deletions devito/types/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,14 +298,14 @@ def _arg_values(self, interval, grid=None, args=None, **kwargs):
# may represent sets of legal values. If that's the case, here we just
# pick one. Note that we sort for determinism
try:
loc_minv = loc_minv.start
loc_minv = loc_minv.stop
except AttributeError:
try:
loc_minv = sorted(loc_minv).pop(0)
except TypeError:
pass
try:
loc_maxv = loc_maxv.start
loc_maxv = loc_maxv.stop
except AttributeError:
try:
loc_maxv = sorted(loc_maxv).pop(0)
Expand Down Expand Up @@ -859,7 +859,8 @@ def _arg_defaults(self, _min=None, size=None, alias=None):
factor = defaults[dim._factor.name] = dim._factor.data
except AttributeError:
factor = dim._factor
defaults[dim.parent.max_name] = range(1, factor*(size))

defaults[dim.parent.max_name] = range(1, factor*size)

return defaults

Expand Down Expand Up @@ -983,8 +984,7 @@ def bound_symbols(self):
return set(self.parent.bound_symbols)

def _arg_defaults(self, alias=None, **kwargs):
dim = alias or self
return {dim.parent.size_name: range(self.symbolic_size, np.iinfo(np.int64).max)}
return {}

def _arg_values(self, *args, **kwargs):
return {}
Expand Down Expand Up @@ -1466,10 +1466,7 @@ def _arg_defaults(self, _min=None, size=None, **kwargs):
A SteppingDimension does not know its max point and therefore
does not have a size argument.
"""
args = {self.parent.min_name: _min}
if size:
args[self.parent.size_name] = range(size-1, np.iinfo(np.int32).max)
return args
return {self.parent.min_name: _min}

def _arg_values(self, *args, **kwargs):
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_buffering.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def test_over_injection():

# Check generated code
assert len(retrieve_iteration_tree(op1)) == \
7 + int(configuration['language'] != 'C')
8 + int(configuration['language'] != 'C')
buffers = [i for i in FindSymbols().visit(op1) if i.is_Array]
assert len(buffers) == 1

Expand Down
25 changes: 25 additions & 0 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from sympy import Float

from conftest import assert_structure
from devito import (Grid, Operator, Dimension, SparseFunction, SparseTimeFunction,
Function, TimeFunction, DefaultDimension, Eq,
PrecomputedSparseFunction, PrecomputedSparseTimeFunction,
Expand Down Expand Up @@ -734,3 +735,27 @@ class SparseFirst(SparseFunction):
op(time_M=10)
expected = 10*11/2 # n (n+1)/2
assert np.allclose(s.data, expected)


def test_inject_function():
nt = 11

grid = Grid(shape=(5, 5))
u = TimeFunction(name="u", grid=grid, time_order=2)
src = SparseTimeFunction(name="src", grid=grid, nt=nt, npoint=1,
coordinates=[[0.5, 0.5]])

nfreq = 5
freq_dim = DefaultDimension(name="freq", default_value=nfreq)
omega = Function(name="omega", dimensions=(freq_dim,), shape=(nfreq,), grid=grid)
omega.data.fill(1.)

inj = src.inject(field=u.forward, expr=omega)

op = Operator([inj])

assert_structure(op, ['p_src', 't', 't,p_src,freq', 't,p_src,freq,rsrcx,rsrcy'],
'p_src,t,p_src,freq,rsrcx,rsrcy')

op(time_M=0)
assert u.data[1, 2, 2] == nfreq

0 comments on commit 6212a00

Please sign in to comment.