Skip to content

Commit

Permalink
compiler: Hotfix and simplify minimize_symbols
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Sep 11, 2024
1 parent 63dcfb1 commit c2a846a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 26 deletions.
3 changes: 3 additions & 0 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,7 +1229,10 @@ def visit_Iteration(self, o):
dimension = uxreplace(o.dim, self.mapper)
limits = [uxreplace(i, self.mapper) for i in o.limits]
pragmas = self._visit(o.pragmas)

uindices = [uxreplace(i, self.mapper) for i in o.uindices]
uindices = filter_ordered(i for i in uindices if isinstance(i, Dimension))

return o._rebuild(nodes=nodes, dimension=dimension, limits=limits,
pragmas=pragmas, uindices=uindices)

Expand Down
41 changes: 15 additions & 26 deletions devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import sympy

from devito.finite_differences import Max, Min
from devito.ir import (Any, Forward, Iteration, List, Prodder, FindApplications,
FindNodes, FindSymbols, Transformer, Uxreplace,
filter_iterations, retrieve_iteration_tree, pull_dims)
from devito.ir import (Any, Forward, List, Prodder, FindApplications, FindNodes,
FindSymbols, Transformer, Uxreplace, filter_iterations,
retrieve_iteration_tree, pull_dims)
from devito.passes.iet.engine import iet_pass
from devito.ir.iet.efunc import DeviceFunction, EntryFunction
from devito.symbolics import (ValueLimit, evalrel, has_integer_args, limits_mapper,
Expand Down Expand Up @@ -246,31 +246,20 @@ def remove_redundant_moddims(iet):
if not mds:
return iet

# ModuloDimensions are defined in Iteration headers, hence they must be
# removed from there first of all
mapper = {}
for n in FindNodes(Iteration).visit(iet):
candidates = [d for d in n.uindices if d in mds]

degenerates, others = split(candidates, lambda d: d.modulo == 1)
subs = {d: sympy.S.Zero for d in degenerates}

redundants = as_mapper(others, key=lambda d: d.offset % d.modulo)
for k, v in redundants.items():
chosen = v.pop(0)
subs.update({d: chosen for d in v})

if subs:
# Expunge the ModuloDimensions from the Iteration header
uindices = [d for d in n.uindices if d not in subs]
iteration = n._rebuild(uindices=uindices)
degenerates, others = split(mds, lambda d: d.modulo == 1)
subs = {d: sympy.S.Zero for d in degenerates}

# Replace the ModuloDimensions in the Iteration body
iteration = Uxreplace(subs).visit(iteration)
redundants = as_mapper(others, key=lambda d: d.offset % d.modulo)
for k, v in redundants.items():
chosen = v.pop(0)
subs.update({d: chosen for d in v})

mapper[n] = iteration

iet = Transformer(mapper, nested=True).visit(iet)
# Transform the `body`, rather than `iet`, to avoid applying substitutions
# to `iet.parameters`, so e.g. `..., t0, t1, t2, ...` remains unchanged
# instead of becoming `..., t0, t1, t1, ...`. The IET `engine` will then
# take care of cleaning up the `parameters` list
body = Uxreplace(subs).visit(iet.body)
iet = iet._rebuild(body=body)

return iet

Expand Down

0 comments on commit c2a846a

Please sign in to comment.