From 428615ac9806794975b331677b0255627518e866 Mon Sep 17 00:00:00 2001 From: mloubout Date: Fri, 15 Sep 2023 09:19:22 -0400 Subject: [PATCH] api: process injected expression dimensions in case it's not the sparse function --- devito/builtins/initializers.py | 7 ++++++- devito/operations/interpolators.py | 23 ++++++++++++++++------- devito/types/dimension.py | 12 ++++-------- tests/test_buffering.py | 2 +- tests/test_interpolation.py | 25 +++++++++++++++++++++++++ 5 files changed, 52 insertions(+), 17 deletions(-) diff --git a/devito/builtins/initializers.py b/devito/builtins/initializers.py index f338e194e1e..83bad735fae 100644 --- a/devito/builtins/initializers.py +++ b/devito/builtins/initializers.py @@ -77,7 +77,12 @@ def assign(f, rhs=0, options=None, name='assign', assign_halo=False, **kwargs): symbolic_max=d.symbolic_max + h.right) eqs = [eq.xreplace(subs) for eq in eqs] - dv.Operator(eqs, name=name, **kwargs)() + op = dv.Operator(eqs, name=name, **kwargs) + try: + op() + except ValueError: + # Corner case such as assign(u, v) with v a Buffered TimeFunction + op(time_M=f._time_size) def smooth(f, g, axis=None): diff --git a/devito/operations/interpolators.py b/devito/operations/interpolators.py index 92bc3923926..d112eef7151 100644 --- a/devito/operations/interpolators.py +++ b/devito/operations/interpolators.py @@ -169,11 +169,17 @@ def _rdim(self): return DimensionTuple(*rdims, getters=self._gdims) - def _augment_implicit_dims(self, implicit_dims): + def _augment_implicit_dims(self, implicit_dims, extras=None): + if extras is not None: + extra = set([i for v in extras for i in v.dimensions]) - set(self._gdims) + extra = tuple(extra) + else: + extra = tuple() + if self.sfunction._sparse_position == -1: - return self.sfunction.dimensions + as_tuple(implicit_dims) + return self.sfunction.dimensions + as_tuple(implicit_dims) + extra else: - return as_tuple(implicit_dims) + self.sfunction.dimensions + return as_tuple(implicit_dims) + self.sfunction.dimensions + extra def _coeff_temps(self, implicit_dims): return [] @@ -252,8 +258,6 @@ def _interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None): interpolation expression, but that should be honored when constructing the operator. """ - implicit_dims = self._augment_implicit_dims(implicit_dims) - # Derivatives must be evaluated before the introduction of indirect accesses try: _expr = expr.evaluate @@ -263,6 +267,9 @@ def _interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None): variables = list(retrieve_function_carriers(_expr)) + # Implicit dimensions + implicit_dims = self._augment_implicit_dims(implicit_dims) + # List of indirection indices for all adjacent grid points idx_subs, temps = self._interp_idx(variables, implicit_dims=implicit_dims) @@ -295,8 +302,6 @@ def _inject(self, field, expr, implicit_dims=None): injection expression, but that should be honored when constructing the operator. """ - implicit_dims = self._augment_implicit_dims(implicit_dims) - # Make iterable to support inject((u, v), expr=expr) # or inject((u, v), expr=(expr1, expr2)) fields, exprs = as_tuple(field), as_tuple(expr) @@ -315,6 +320,10 @@ def _inject(self, field, expr, implicit_dims=None): _exprs = exprs variables = list(v for e in _exprs for v in retrieve_function_carriers(e)) + + # Implicit dimensions + implicit_dims = self._augment_implicit_dims(implicit_dims, variables) + variables = variables + list(fields) # List of indirection indices for all adjacent grid points diff --git a/devito/types/dimension.py b/devito/types/dimension.py index 6044f014690..548a26d012f 100644 --- a/devito/types/dimension.py +++ b/devito/types/dimension.py @@ -298,14 +298,14 @@ def _arg_values(self, interval, grid=None, args=None, **kwargs): # may represent sets of legal values. If that's the case, here we just # pick one. Note that we sort for determinism try: - loc_minv = loc_minv.start + loc_minv = loc_minv.stop except AttributeError: try: loc_minv = sorted(loc_minv).pop(0) except TypeError: pass try: - loc_maxv = loc_maxv.start + loc_maxv = loc_maxv.stop except AttributeError: try: loc_maxv = sorted(loc_maxv).pop(0) @@ -983,8 +983,7 @@ def bound_symbols(self): return set(self.parent.bound_symbols) def _arg_defaults(self, alias=None, **kwargs): - dim = alias or self - return {dim.parent.size_name: range(self.symbolic_size, np.iinfo(np.int64).max)} + return {} def _arg_values(self, *args, **kwargs): return {} @@ -1466,10 +1465,7 @@ def _arg_defaults(self, _min=None, size=None, **kwargs): A SteppingDimension does not know its max point and therefore does not have a size argument. """ - args = {self.parent.min_name: _min} - if size: - args[self.parent.size_name] = range(size-1, np.iinfo(np.int32).max) - return args + return {self.parent.min_name: _min} def _arg_values(self, *args, **kwargs): """ diff --git a/tests/test_buffering.py b/tests/test_buffering.py index 16f98b4f940..ba200d220c7 100644 --- a/tests/test_buffering.py +++ b/tests/test_buffering.py @@ -272,7 +272,7 @@ def test_over_injection(): # Check generated code assert len(retrieve_iteration_tree(op1)) == \ - 7 + int(configuration['language'] != 'C') + 8 + int(configuration['language'] != 'C') buffers = [i for i in FindSymbols().visit(op1) if i.is_Array] assert len(buffers) == 1 diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index 3a22ca1db73..79e816a01b6 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -4,6 +4,7 @@ import pytest from sympy import Float +from conftest import assert_structure from devito import (Grid, Operator, Dimension, SparseFunction, SparseTimeFunction, Function, TimeFunction, DefaultDimension, Eq, PrecomputedSparseFunction, PrecomputedSparseTimeFunction, @@ -734,3 +735,27 @@ class SparseFirst(SparseFunction): op(time_M=10) expected = 10*11/2 # n (n+1)/2 assert np.allclose(s.data, expected) + + +def test_inject_function(): + nt = 11 + + grid = Grid(shape=(5, 5)) + u = TimeFunction(name="u", grid=grid, time_order=2) + src = SparseTimeFunction(name="src", grid=grid, nt=nt, npoint=1, + coordinates=[[0.5, 0.5]]) + + nfreq = 5 + freq_dim = DefaultDimension(name="freq", default_value=nfreq) + omega = Function(name="omega", dimensions=(freq_dim,), shape=(nfreq,), grid=grid) + omega.data.fill(1.) + + inj = src.inject(field=u.forward, expr=omega) + + op = Operator([inj]) + + assert_structure(op, ['p_src', 't', 't,p_src,freq', 't,p_src,freq,rsrcx,rsrcy'], + 'p_src,t,p_src,freq,rsrcx,rsrcy') + + op(time_M=0) + assert u.data[1, 2, 2] == nfreq