From d62b080986ef7807e0463ba8ec8a3ab7b16f29c2 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Fri, 20 Oct 2023 13:50:43 +0000 Subject: [PATCH] compiler: Fix unexpansion w custom coeffs --- devito/finite_differences/coefficients.py | 7 ++++++- devito/finite_differences/differentiable.py | 15 +++++++++++++++ tests/test_unexpansion.py | 17 +++++++++++++++++ 3 files changed, 38 insertions(+), 1 deletion(-) diff --git a/devito/finite_differences/coefficients.py b/devito/finite_differences/coefficients.py index 1a401b57e1..ab83e1b1c4 100644 --- a/devito/finite_differences/coefficients.py +++ b/devito/finite_differences/coefficients.py @@ -1,7 +1,7 @@ import numpy as np from cached_property import cached_property -from devito.finite_differences import generate_indices +from devito.finite_differences import Weights, generate_indices from devito.finite_differences.tools import numeric_weights, symbolic_weights from devito.tools import filter_ordered, as_tuple @@ -268,8 +268,13 @@ def generate_subs(deriv_order, function, index): return subs # Determine which 'rules' are missing + sym = get_sym(functions) terms = obj.find(sym) + for i in obj.find(Weights): + for w in i.weights: + terms.update(w.find(sym)) + args_present = filter_ordered(term.args[1:] for term in terms) subs = obj.substitutions diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 170852a50a..7ef93508cb 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -637,6 +637,21 @@ def spacings(self): weights = Array.initvalue + def _xreplace(self, rule): + if self in rule: + return rule[self], True + elif not rule: + return self, False + else: + try: + weights, flags = zip(*[i._xreplace(rule) for i in self.weights]) + if any(flags): + return self.func(initvalue=weights, function=None), True + except AttributeError: + # `float` weights + pass + return super()._xreplace(rule) + class IndexDerivative(IndexSum): diff --git a/tests/test_unexpansion.py b/tests/test_unexpansion.py index fa076096a3..1e269328c1 100644 --- a/tests/test_unexpansion.py +++ b/tests/test_unexpansion.py @@ -21,6 +21,23 @@ def test_backward_dt2(self): assert_structure(op, ['t,x,y'], 't,x,y') +class TestSymbolicCoefficients(object): + + def test_fallback_to_default(self): + grid = Grid(shape=(8, 8, 8)) + + u = TimeFunction(name='u', grid=grid, coefficients='symbolic', + space_order=4, time_order=2) + + eq = Eq(u.forward, u.dx2 + 1) + + op = Operator(eq, opt=('advanced', {'expand': False})) + + # Ensure all symbols have been resolved + op.arguments(dt=1, time_M=10) + op.cfunction + + class Test1Pass(object): def test_v0(self):