Skip to content

Commit

Permalink
compile: Improve codegen aesthetics of cire-rotate
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Oct 21, 2024
1 parent bac17a1 commit 77ac0de
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 9 deletions.
33 changes: 30 additions & 3 deletions devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import sympy

from devito.finite_differences import Max, Min
from devito.ir import (Any, Forward, List, Prodder, FindApplications, FindNodes,
FindSymbols, Transformer, Uxreplace, filter_iterations,
retrieve_iteration_tree, pull_dims)
from devito.ir import (Any, Forward, DummyExpr, Iteration, List, Prodder,
FindApplications, FindNodes, FindSymbols, Transformer,
Uxreplace, filter_iterations, retrieve_iteration_tree,
pull_dims)
from devito.passes.iet.engine import iet_pass
from devito.ir.iet.efunc import DeviceFunction, EntryFunction
from devito.symbolics import (ValueLimit, evalrel, has_integer_args, limits_mapper,
Expand Down Expand Up @@ -231,10 +232,13 @@ def minimize_symbols(iet):
* Remove redundant ModuloDimensions (e.g., due to using the
`save=Buffer(2)` API)
* Simplify Iteration headers (e.g., ModuloDimensions with identical
starting point and step)
* Abridge SubDimension names where possible to declutter generated
loop nests and shrink indices
"""
iet = remove_redundant_moddims(iet)
iet = simplify_iteration_headers(iet)
iet = abridge_dim_names(iet)

return iet, {}
Expand Down Expand Up @@ -264,6 +268,29 @@ def remove_redundant_moddims(iet):
return iet


def simplify_iteration_headers(iet):
mapper = {}
for i in FindNodes(Iteration).visit(iet):
candidates = [d for d in i.uindices
if d.is_Modulo and d.symbolic_min == d.symbolic_incr]

# Don't touch `t0, t1, ...` for codegen aesthetics and to avoid
# massive changes in the test suite
candidates = [d for d in candidates if not d.is_Time]

if not candidates:
continue

uindices = [d for d in i.uindices if d not in candidates]
stmts = [DummyExpr(d, d.symbolic_incr, init=True) for d in candidates]

mapper[i] = i._rebuild(nodes=tuple(stmts) + i.nodes, uindices=uindices)

iet = Transformer(mapper, nested=True).visit(iet)

return iet


@singledispatch
def abridge_dim_names(iet):
return iet
Expand Down
12 changes: 6 additions & 6 deletions tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,13 +1131,13 @@ def test_from_different_nests(self, rotate):
# Check code generation
bns, _ = assert_blocking(op1, {'x0_blk0', 'x1_blk0'})
trees = retrieve_iteration_tree(bns['x0_blk0'])
assert len(trees) == 2
assert trees[0][-1].nodes[0].body[0].write.is_Array
assert trees[1][-1].nodes[0].body[0].write is u
assert len(trees) == 4 if rotate else 2
assert trees[-2][-1].nodes[0].body[0].write.is_Array
assert trees[-1][-1].nodes[0].body[0].write is u
trees = retrieve_iteration_tree(bns['x1_blk0'])
assert len(trees) == 2
assert trees[0][-1].nodes[0].body[0].write.is_Array
assert trees[1][-1].nodes[0].body[0].write is v
assert len(trees) == 4 if rotate else 2
assert trees[-2][-1].nodes[0].body[0].write.is_Array
assert trees[-1][-1].nodes[0].body[0].write is v

# Check numerical output
op0(time_M=1)
Expand Down

0 comments on commit 77ac0de

Please sign in to comment.