diff --git a/devito/finite_differences/coefficients.py b/devito/finite_differences/coefficients.py index 4f7e48cf3a..0650f3dae5 100644 --- a/devito/finite_differences/coefficients.py +++ b/devito/finite_differences/coefficients.py @@ -6,7 +6,7 @@ from devito.tools import filter_ordered, as_tuple from devito.symbolics.search import retrieve_dimensions -__all__ = ['Coefficient', 'Substitutions', 'default_rules'] +__all__ = ['Coefficient', 'Substitutions', 'default_rules', 'all_rules'] class Coefficient(object): @@ -227,6 +227,34 @@ def generate_subs(i): return rules + def update_rules(self, obj): + """Update the specified rules to reflect staggering in an equation""" + # Determine which 'rules' are expected + sym = get_sym(self._function_list) + terms = obj.find(sym) + args_expected = filter_ordered(term.args[1:] for term in terms) + args_expected_dim = [(arg[0], arg[1], retrieve_dimensions(arg[2])[0]) + for arg in args_expected] + + # Modify dictionary keys where expected index does not match index in rules + rules = self._rules.copy() # Get a copy to modify, to preserve base rules + for rule in self._rules: + rule_arg = rule.args[1:] + rule_arg_dim = (rule_arg[0], rule_arg[1], + retrieve_dimensions(rule_arg[2])[0]) + if rule_arg_dim in args_expected_dim and rule_arg not in args_expected: + # Rule matches expected in terms of dimensions, but index is + # mismatched (due to staggering of equation) + + # Find index in args_expected_dim + pos = args_expected_dim.index(rule_arg_dim) + # Replace key in rules with one using modified index taken from + # the expected + replacement = rule.args[:-1] + (args_expected[pos][-1],) + rules[sym(*replacement)] = rules.pop(rule) + + return rules + def default_rules(obj, functions): @@ -262,9 +290,10 @@ def generate_subs(deriv_order, function, index): args_present = filter_ordered(term.args[1:] for term in terms) subs = obj.substitutions + if subs: - args_provided = [(i.deriv_order, i.function, i.index) - for i in subs.coefficients] + # Check against the updated rules when determining rules the user has provided + args_provided = filter_ordered(rule.args[1:] for rule in subs.update_rules(obj)) else: args_provided = [] @@ -289,3 +318,17 @@ def get_sym(functions): pass # Shouldn't arrive here raise TypeError("Failed to retreive symbol") + + +def all_rules(obj, functions): + """Return all substitution rules for an Eq""" + # Default rules + d_rules = default_rules(obj, functions) + # User rules + subs = obj.substitutions + if subs: + u_rules = subs.update_rules(obj) + else: + u_rules = {} + + return {**d_rules, **u_rules} diff --git a/devito/types/equation.py b/devito/types/equation.py index 4b86a59ec9..c53067ccf5 100644 --- a/devito/types/equation.py +++ b/devito/types/equation.py @@ -5,7 +5,7 @@ from cached_property import cached_property -from devito.finite_differences import default_rules +from devito.finite_differences import all_rules from devito.logger import warning from devito.tools import as_tuple from devito.types.lazy import Evaluable @@ -95,9 +95,9 @@ def evaluate(self): if eq._uses_symbolic_coefficients: # NOTE: As Coefficients.py is expanded we will not want # all rules to be expunged during this procress. - rules = default_rules(eq, eq._symbolic_functions) + rules = all_rules(eq, eq._symbolic_functions) try: - eq = eq.xreplace({**eq.substitutions.rules, **rules}) + eq = eq.xreplace(rules) except AttributeError: if bool(rules): eq = eq.xreplace(rules)