From 7af2372475c2f4b80017be15556b5e361b8ab689 Mon Sep 17 00:00:00 2001
From: Fabio Luporini
Date: Mon, 21 Oct 2024 08:34:45 +0000
Subject: [PATCH] compiler: Improve codegen aesthetics of cire-rotate

---
Reviewer note (this region is ignored when the patch is applied):
in tests/test_dse.py::test_from_different_nests the tree-count checks are
written as `len(trees) == (4 if rotate else 2)`. Without the parentheses
the conditional expression binds looser than `==`, i.e. the statement
reads `assert (len(trees) == 4) if rotate else 2`, which asserts the
constant `2` (always true) whenever `rotate` is False.

 devito/passes/iet/misc.py | 34 +++++++++++++++++++++++++++++++---
 tests/test_dle.py         | 16 ++++++++--------
 tests/test_dse.py         | 25 ++++++++++++++++---------
 3 files changed, 55 insertions(+), 20 deletions(-)

diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py
index 4617b47336..e9d8308336 100644
--- a/devito/passes/iet/misc.py
+++ b/devito/passes/iet/misc.py
@@ -5,9 +5,10 @@
 import sympy
 
 from devito.finite_differences import Max, Min
-from devito.ir import (Any, Forward, List, Prodder, FindApplications, FindNodes,
-                       FindSymbols, Transformer, Uxreplace, filter_iterations,
-                       retrieve_iteration_tree, pull_dims)
+from devito.ir import (Any, Forward, DummyExpr, Iteration, List, Prodder,
+                       FindApplications, FindNodes, FindSymbols, Transformer,
+                       Uxreplace, filter_iterations, retrieve_iteration_tree,
+                       pull_dims)
 from devito.passes.iet.engine import iet_pass
 from devito.ir.iet.efunc import DeviceFunction, EntryFunction
 from devito.symbolics import (ValueLimit, evalrel, has_integer_args, limits_mapper,
@@ -231,10 +232,13 @@ def minimize_symbols(iet):
 
         * Remove redundant ModuloDimensions (e.g., due to using the
           `save=Buffer(2)` API)
+        * Simplify Iteration headers (e.g., ModuloDimensions with identical
+          starting point and step)
         * Abridge SubDimension names where possible to declutter generated
           loop nests and shrink indices
     """
     iet = remove_redundant_moddims(iet)
+    iet = simplify_iteration_headers(iet)
     iet = abridge_dim_names(iet)
 
     return iet, {}
@@ -264,6 +268,30 @@ def remove_redundant_moddims(iet):
     return iet
 
 
+def simplify_iteration_headers(iet):
+    mapper = {}
+    for i in FindNodes(Iteration).visit(iet):
+        candidates = [d for d in i.uindices
+                      if d.is_Modulo and d.symbolic_min == d.symbolic_incr]
+
+        # Don't touch `t0, t1, ...` for codegen aesthetics and to avoid
+        # massive changes in the test suite
+        candidates = [d for d in candidates
+                      if not any(dd.is_Time for dd in d._defines)]
+
+        if not candidates:
+            continue
+
+        uindices = [d for d in i.uindices if d not in candidates]
+        stmts = [DummyExpr(d, d.symbolic_incr, init=True) for d in candidates]
+
+        mapper[i] = i._rebuild(nodes=tuple(stmts) + i.nodes, uindices=uindices)
+
+    iet = Transformer(mapper, nested=True).visit(iet)
+
+    return iet
+
+
 @singledispatch
 def abridge_dim_names(iet):
     return iet
diff --git a/tests/test_dle.py b/tests/test_dle.py
index efc41bd2bc..45583d4346 100644
--- a/tests/test_dle.py
+++ b/tests/test_dle.py
@@ -1319,17 +1319,17 @@ def test_multiple_subnests_v1(self):
         bns, _ = assert_blocking(op, {'x0_blk0'})
 
         trees = retrieve_iteration_tree(bns['x0_blk0'])
-        assert len(trees) == 2
+        assert len(trees) == 4
 
-        assert trees[0][0] is trees[1][0]
-        assert trees[0][0].pragmas[0].ccode.value ==\
+        assert len(set(i.root for i in trees)) == 1
+        assert trees[-2].root.pragmas[0].ccode.value ==\
             'omp for collapse(2) schedule(dynamic,1)'
-        assert not trees[0][2].pragmas
-        assert not trees[0][3].pragmas
-        assert trees[0][4].pragmas[0].ccode.value ==\
+        assert not trees[-2][2].pragmas
+        assert not trees[-2][3].pragmas
+        assert trees[-2][4].pragmas[0].ccode.value ==\
             'omp parallel for schedule(dynamic,1) num_threads(nthreads_nested)'
-        assert not trees[1][2].pragmas
-        assert trees[1][3].pragmas[0].ccode.value ==\
+        assert not trees[-1][2].pragmas
+        assert trees[-1][3].pragmas[0].ccode.value ==\
             'omp parallel for schedule(dynamic,1) num_threads(nthreads_nested)'
 
     @pytest.mark.parametrize('blocklevels', [1, 2])
diff --git a/tests/test_dse.py b/tests/test_dse.py
index 36676a88a5..49cece3b34 100644
--- a/tests/test_dse.py
+++ b/tests/test_dse.py
@@ -1131,13 +1131,13 @@ def test_from_different_nests(self, rotate):
         # Check code generation
         bns, _ = assert_blocking(op1, {'x0_blk0', 'x1_blk0'})
         trees = retrieve_iteration_tree(bns['x0_blk0'])
-        assert len(trees) == 2
-        assert trees[0][-1].nodes[0].body[0].write.is_Array
-        assert trees[1][-1].nodes[0].body[0].write is u
+        assert len(trees) == (4 if rotate else 2)
+        assert trees[-2][-1].nodes[0].body[0].write.is_Array
+        assert trees[-1][-1].nodes[0].body[0].write is u
         trees = retrieve_iteration_tree(bns['x1_blk0'])
-        assert len(trees) == 2
-        assert trees[0][-1].nodes[0].body[0].write.is_Array
-        assert trees[1][-1].nodes[0].body[0].write is v
+        assert len(trees) == (4 if rotate else 2)
+        assert trees[-2][-1].nodes[0].body[0].write.is_Array
+        assert trees[-1][-1].nodes[0].body[0].write is v
 
         # Check numerical output
         op0(time_M=1)
@@ -2093,9 +2093,12 @@ def test_maxpar_option(self, rotate):
         # Check code generation
         bns, _ = assert_blocking(op1, {'x0_blk0'})
         trees = retrieve_iteration_tree(bns['x0_blk0'])
-        assert len(trees) == 2
+        if rotate:
+            assert len(trees) == 5
+        else:
+            assert len(trees) == 2
+            assert trees[0][2] is not trees[1][2]
         assert trees[0][1] is trees[1][1]
-        assert trees[0][2] is not trees[1][2]
 
         # Check numerical output
         op0.apply(time_M=2)
@@ -2241,7 +2244,11 @@ def test_blocking_options(self, rotate):
         if rotate:
             assert_structure(
                 op1,
-                prefix + ['t,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,xc,y,z',
+                prefix + ['t,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x',
+                          't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,xc',
+                          't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,xc,y,z',
+                          't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,y',
+                          't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,y,yc',
                           't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,y,yc,z',
                           't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,y,z'],
                 't,x0_blk0,y0_blk0,x0_blk1,y0_blk1,x,xc,y,z,y,yc,z,z'