diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 4617b47336..56cbe73f52 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -5,9 +5,10 @@ import sympy from devito.finite_differences import Max, Min -from devito.ir import (Any, Forward, List, Prodder, FindApplications, FindNodes, - FindSymbols, Transformer, Uxreplace, filter_iterations, - retrieve_iteration_tree, pull_dims) +from devito.ir import (Any, Forward, DummyExpr, Iteration, List, Prodder, + FindApplications, FindNodes, FindSymbols, Transformer, + Uxreplace, filter_iterations, retrieve_iteration_tree, + pull_dims) from devito.passes.iet.engine import iet_pass from devito.ir.iet.efunc import DeviceFunction, EntryFunction from devito.symbolics import (ValueLimit, evalrel, has_integer_args, limits_mapper, @@ -231,10 +232,13 @@ def minimize_symbols(iet): * Remove redundant ModuloDimensions (e.g., due to using the `save=Buffer(2)` API) + * Simplify Iteration headers (e.g., ModuloDimensions with identical + starting point and step) * Abridge SubDimension names where possible to declutter generated loop nests and shrink indices """ iet = remove_redundant_moddims(iet) + iet = simplify_iteration_headers(iet) iet = abridge_dim_names(iet) return iet, {} @@ -264,6 +268,29 @@ def remove_redundant_moddims(iet): return iet +def simplify_iteration_headers(iet): + mapper = {} + for i in FindNodes(Iteration).visit(iet): + candidates = [d for d in i.uindices + if d.is_Modulo and d.symbolic_min == d.symbolic_incr] + + # Don't touch `t0, t1, ...` for codegen aesthetics and to avoid + # massive changes in the test suite + candidates = [d for d in candidates if not d.is_Time] + + if not candidates: + continue + + uindices = [d for d in i.uindices if d not in candidates] + stmts = [DummyExpr(d, d.symbolic_incr, init=True) for d in candidates] + + mapper[i] = i._rebuild(nodes=tuple(stmts) + i.nodes, uindices=uindices) + + iet = Transformer(mapper, nested=True).visit(iet) + + return iet + + @singledispatch def abridge_dim_names(iet): return iet diff --git a/tests/test_dse.py b/tests/test_dse.py index 7fd0298d93..e0450bc61d 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -1131,13 +1131,13 @@ def test_from_different_nests(self, rotate): # Check code generation bns, _ = assert_blocking(op1, {'x0_blk0', 'x1_blk0'}) trees = retrieve_iteration_tree(bns['x0_blk0']) - assert len(trees) == 2 - assert trees[0][-1].nodes[0].body[0].write.is_Array - assert trees[1][-1].nodes[0].body[0].write is u + assert len(trees) == 4 if rotate else 2 + assert trees[-2][-1].nodes[0].body[0].write.is_Array + assert trees[-1][-1].nodes[0].body[0].write is u trees = retrieve_iteration_tree(bns['x1_blk0']) - assert len(trees) == 2 - assert trees[0][-1].nodes[0].body[0].write.is_Array - assert trees[1][-1].nodes[0].body[0].write is v + assert len(trees) == 4 if rotate else 2 + assert trees[-2][-1].nodes[0].body[0].write.is_Array + assert trees[-1][-1].nodes[0].body[0].write is v # Check numerical output op0(time_M=1)