diff --git a/devito/types/tensor.py b/devito/types/tensor.py index 6c94c5fc3af..62a555ef78b 100644 --- a/devito/types/tensor.py +++ b/devito/types/tensor.py @@ -173,7 +173,13 @@ def is_diagonal(self): for i in range(self.rows) if i != j]) def _evaluate(self, **kwargs): - return self.applyfunc(lambda x: getattr(x, 'evaluate', x)) + def _do_evaluate(x): + try: + expand = kwargs.get('expand', True) + return x._evaluate(expand=expand) + except AttributeError: + return x + return self.applyfunc(_do_evaluate) def values(self): if self.is_diagonal: diff --git a/tests/test_derivatives.py b/tests/test_derivatives.py index 8031932df55..5fa11a6fa02 100644 --- a/tests/test_derivatives.py +++ b/tests/test_derivatives.py @@ -766,6 +766,16 @@ def test_transpose(self): i0, = term.dimensions assert term.base == f.subs(x, x + i0*h_x) + def test_tensor_algebra(self): + grid = Grid(shape=(4, 4)) + + f = Function(name='f', grid=grid, space_order=4) + + v = grad(f)._evaluate(expand=False) + + assert all(isinstance(i, IndexDerivative) for i in v) + assert all(zip([Add(*i.args) for i in grad(f).evaluate], v.evaluate)) + def bypass_uneval(expr): unevals = expr.find(EvalDerivative)