From c2a846a8eedbec70dc349232acf9fa22c581a61c Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 11 Sep 2024 06:52:11 +0000 Subject: [PATCH] compiler: Hotfix and simplify minimize_symbols --- devito/ir/iet/visitors.py | 3 +++ devito/passes/iet/misc.py | 41 ++++++++++++++------------------------- 2 files changed, 18 insertions(+), 26 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 54e9188e1a..505fe2e001 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -1229,7 +1229,10 @@ def visit_Iteration(self, o): dimension = uxreplace(o.dim, self.mapper) limits = [uxreplace(i, self.mapper) for i in o.limits] pragmas = self._visit(o.pragmas) + uindices = [uxreplace(i, self.mapper) for i in o.uindices] + uindices = filter_ordered(i for i in uindices if isinstance(i, Dimension)) + return o._rebuild(nodes=nodes, dimension=dimension, limits=limits, pragmas=pragmas, uindices=uindices) diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 0b4ed39149..4617b47336 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -5,9 +5,9 @@ import sympy from devito.finite_differences import Max, Min -from devito.ir import (Any, Forward, Iteration, List, Prodder, FindApplications, - FindNodes, FindSymbols, Transformer, Uxreplace, - filter_iterations, retrieve_iteration_tree, pull_dims) +from devito.ir import (Any, Forward, List, Prodder, FindApplications, FindNodes, + FindSymbols, Transformer, Uxreplace, filter_iterations, + retrieve_iteration_tree, pull_dims) from devito.passes.iet.engine import iet_pass from devito.ir.iet.efunc import DeviceFunction, EntryFunction from devito.symbolics import (ValueLimit, evalrel, has_integer_args, limits_mapper, @@ -246,31 +246,20 @@ def remove_redundant_moddims(iet): if not mds: return iet - # ModuloDimensions are defined in Iteration headers, hence they must be - # removed from there first of all - mapper = {} - for n in FindNodes(Iteration).visit(iet): - candidates = [d for d in n.uindices if d in mds] - - degenerates, others = split(candidates, lambda d: d.modulo == 1) - subs = {d: sympy.S.Zero for d in degenerates} - - redundants = as_mapper(others, key=lambda d: d.offset % d.modulo) - for k, v in redundants.items(): - chosen = v.pop(0) - subs.update({d: chosen for d in v}) - - if subs: - # Expunge the ModuloDimensions from the Iteration header - uindices = [d for d in n.uindices if d not in subs] - iteration = n._rebuild(uindices=uindices) + degenerates, others = split(mds, lambda d: d.modulo == 1) + subs = {d: sympy.S.Zero for d in degenerates} - # Replace the ModuloDimensions in the Iteration body - iteration = Uxreplace(subs).visit(iteration) + redundants = as_mapper(others, key=lambda d: d.offset % d.modulo) + for k, v in redundants.items(): + chosen = v.pop(0) + subs.update({d: chosen for d in v}) - mapper[n] = iteration - - iet = Transformer(mapper, nested=True).visit(iet) + # Transform the `body`, rather than `iet`, to avoid applying substitutions + # to `iet.parameters`, so e.g. `..., t0, t1, t2, ...` remains unchanged + # instead of becoming `..., t0, t1, t1, ...`. The IET `engine` will then + # take care of cleaning up the `parameters` list + body = Uxreplace(subs).visit(iet.body) + iet = iet._rebuild(body=body) return iet