From 1563b807b8c4c2f810d611b01e22eede05a03740 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Mon, 9 Sep 2024 15:13:58 +0000 Subject: [PATCH] compiler: Add wrapper for subs vs uxreplace --- devito/symbolics/manipulation.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index 1762c4250d..11a24d5ea7 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -6,19 +6,21 @@ from sympy.core.add import _addsort from sympy.core.mul import _mulsort -from devito.finite_differences.differentiable import EvalDerivative +from devito.finite_differences.differentiable import ( + EvalDerivative, IndexDerivative +) from devito.symbolics.extended_sympy import DefFunction, rfunc from devito.symbolics.queries import q_leaf from devito.symbolics.search import retrieve_indexed, retrieve_functions from devito.tools import as_list, as_tuple, flatten, split, transitive_closure -from devito.types.basic import Basic +from devito.types.basic import Basic, Indexed from devito.types.array import ComponentAccess from devito.types.equation import Eq from devito.types.relational import Le, Lt, Gt, Ge __all__ = ['xreplace_indices', 'pow_to_mul', 'indexify', 'subs_op_args', - 'normalize_args', 'uxreplace', 'Uxmapper', 'reuse_if_untouched', - 'evalrel', 'flatten_args'] + 'normalize_args', 'uxreplace', 'Uxmapper', 'subs_if_composite', + 'reuse_if_untouched', 'evalrel', 'flatten_args'] def uxreplace(expr, rule): @@ -244,6 +246,20 @@ def add(self, expr, make, terms=None): self[base] = self.extracted[base] = make() +def subs_if_composite(expr, subs): + """ + Call `expr.subs(subs)` if `subs` contain composite expressions, that is + expressions that can be part of larger expressions of the same type (e.g., + `a*b` could be part of `a*b*c`, while `a[1]` cannot be part of a "larger + Indexed"). Instead, if `subs` consists of just "primitive" expressions, then + resort to the much faster `uxreplace`. + """ + if all(isinstance(i, (Indexed, IndexDerivative)) for i in subs): + return uxreplace(expr, subs) + else: + return expr.subs(subs) + + def xreplace_indices(exprs, mapper, key=None): """ Replace array indices in expressions.