Skip to content

Commit

Permalink
dsl: added all_rules to return all required rules including modified …
Browse files Browse the repository at this point in the history
…user-supplied rules
  • Loading branch information
EdCaunt committed Mar 23, 2021
1 parent 4471237 commit 44d70e1
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 6 deletions.
49 changes: 46 additions & 3 deletions devito/finite_differences/coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from devito.tools import filter_ordered, as_tuple
from devito.symbolics.search import retrieve_dimensions

__all__ = ['Coefficient', 'Substitutions', 'default_rules']
__all__ = ['Coefficient', 'Substitutions', 'default_rules', 'all_rules']


class Coefficient(object):
Expand Down Expand Up @@ -227,6 +227,34 @@ def generate_subs(i):

return rules

def update_rules(self, obj):
"""Update the specified rules to reflect staggering in an equation"""
# Determine which 'rules' are expected
sym = get_sym(self._function_list)
terms = obj.find(sym)
args_expected = filter_ordered(term.args[1:] for term in terms)
args_expected_dim = [(arg[0], arg[1], retrieve_dimensions(arg[2])[0])
for arg in args_expected]

# Modify dictionary keys where expected index does not match index in rules
rules = self._rules.copy() # Get a copy to modify, to preserve base rules
for rule in self._rules:
rule_arg = rule.args[1:]
rule_arg_dim = (rule_arg[0], rule_arg[1],
retrieve_dimensions(rule_arg[2])[0])
if rule_arg_dim in args_expected_dim and rule_arg not in args_expected:
# Rule matches expected in terms of dimensions, but index is
# mismatched (due to staggering of equation)

# Find index in args_expected_dim
pos = args_expected_dim.index(rule_arg_dim)
# Replace key in rules with one using modified index taken from
# the expected
replacement = rule.args[:-1] + (args_expected[pos][-1],)
rules[sym(*replacement)] = rules.pop(rule)

return rules


def default_rules(obj, functions):

Expand Down Expand Up @@ -262,9 +290,10 @@ def generate_subs(deriv_order, function, index):
args_present = filter_ordered(term.args[1:] for term in terms)

subs = obj.substitutions

if subs:
args_provided = [(i.deriv_order, i.function, i.index)
for i in subs.coefficients]
# Check against the updated rules when determining rules the user has provided
args_provided = filter_ordered(rule.args[1:] for rule in subs.update_rules(obj))
else:
args_provided = []

Expand All @@ -289,3 +318,17 @@ def get_sym(functions):
pass
# Shouldn't arrive here
raise TypeError("Failed to retreive symbol")


def all_rules(obj, functions):
"""Return all substitution rules for an Eq"""
# Default rules
d_rules = default_rules(obj, functions)
# User rules
subs = obj.substitutions
if subs:
u_rules = subs.update_rules(obj)
else:
u_rules = {}

return {**d_rules, **u_rules}
6 changes: 3 additions & 3 deletions devito/types/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from cached_property import cached_property

from devito.finite_differences import default_rules
from devito.finite_differences import all_rules
from devito.logger import warning
from devito.tools import as_tuple
from devito.types.lazy import Evaluable
Expand Down Expand Up @@ -95,9 +95,9 @@ def evaluate(self):
if eq._uses_symbolic_coefficients:
# NOTE: As Coefficients.py is expanded we will not want
# all rules to be expunged during this procress.
rules = default_rules(eq, eq._symbolic_functions)
rules = all_rules(eq, eq._symbolic_functions)
try:
eq = eq.xreplace({**eq.substitutions.rules, **rules})
eq = eq.xreplace(rules)
except AttributeError:
if bool(rules):
eq = eq.xreplace(rules)
Expand Down

0 comments on commit 44d70e1

Please sign in to comment.