From 31344be9fbd6e4b17bdcccc0ae33d9a97ae403fe Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 31 Oct 2024 10:08:01 +0000 Subject: [PATCH] compiler: Improve CIRE's cost model --- devito/passes/clusters/aliases.py | 55 +++++++++++++++++++++++--- tests/test_dse.py | 66 ++++++++++++++++++++++++++++++- 2 files changed, 114 insertions(+), 7 deletions(-) diff --git a/devito/passes/clusters/aliases.py b/devito/passes/clusters/aliases.py index 8b7f610564..ca5343c9ac 100644 --- a/devito/passes/clusters/aliases.py +++ b/devito/passes/clusters/aliases.py @@ -958,15 +958,58 @@ def pick_best(variants): flops = flops0 + flops1 - # Data movement in the two sweeps - indexeds0 = search([sa.pivot for sa in i.schedule], Indexed) + # Estimate the data movement in the two sweeps + + # With cross-loop blocking, a Function appearing in both sweeps is + # much more likely to be in cache during the second sweep, hence + # we count it only once + functions0 = set() + functions1 = set() + for sa in i.schedule: + indexeds0 = search(sa.pivot, Indexed) + + if any(d.is_Block for d in sa.ispace.itdims): + functions1.update({i.function for i in indexeds0}) + else: + functions0.update({i.function for i in indexeds0}) + indexeds1 = search(i.exprs, Indexed) + functions1.update({i.function for i in indexeds1}) + + nfunctions0 = len(functions0) + nfunctions1 = len(functions1) + + # All temporaries impact data movement, but some kind of temporaries + # are more likely to be in cache than others, so they are given a + # lighter weight + for ii in indexeds1: + grid = ii.function.grid + if grid is None: + continue - ntemps = len(i.schedule) - nfunctions0 = len({i.function for i in indexeds0}) - nfunctions1 = len({i.function for i in indexeds1}) + ntemps = 0 + for sa in i.schedule: + if len(sa.writeto) < grid.dim: + # Tiny temporary, extremely likely to be in cache, hardly + # impacting data movement in a significant way + ntemps += 0.1 + elif any(d.is_Block for d in sa.writeto.itdims): + # Cross-loop blocking temporary, likely to be in some level + # of cache (but unlikely to be in the fastest level) + ntemps += 1 + else: + # Grid-size temporary, likely _not_ to be in cache, and + # therefore requiring at least two costly accesses per + # grid point + ntemps += 2 + + ntemps = int(ntemps) + + break + else: + ntemps = len(i.schedule) - ws = ntemps*2 + nfunctions0 + nfunctions1 + ws = ntemps + nfunctions0 + nfunctions1 if best is None: best, best_flops, best_ws = i, flops, ws diff --git a/tests/test_dse.py b/tests/test_dse.py index 49cece3b34..33f93bf5e5 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -1915,6 +1915,70 @@ def test_tti_adjoint_akin_v2(self): # all redundancies have been detected correctly assert summary1[('section1', None)].ops == 75 + @switchconfig(profiling='advanced') + def test_tti_adjoint_akin_v3(self): + so = 8 + fd_order = 2 + + grid = Grid(shape=(20, 20, 20)) + x, y, z = grid.dimensions + + vx = TimeFunction(name="vx", grid=grid, space_order=so) + vy = TimeFunction(name="vy", grid=grid, space_order=so) + vz = TimeFunction(name="vz", grid=grid, space_order=so) + txy = TimeFunction(name="txy", grid=grid, space_order=so) + txz = TimeFunction(name="txz", grid=grid, space_order=so) + theta = Function(name='theta', grid=grid, space_order=so) + phi = Function(name='phi', grid=grid, space_order=so) + + r00 = cos(theta)*cos(phi) + r01 = cos(theta)*sin(phi) + r02 = -sin(theta) + r10 = -sin(phi) + r11 = cos(phi) + r12 = cos(theta) + r20 = sin(theta)*cos(phi) + r21 = sin(theta)*sin(phi) + r22 = cos(theta) + + def foo0(field): + return ((r00 * field).dx(x0=x+x.spacing/2) + + Derivative(r01 * field, x, deriv_order=0, fd_order=fd_order, + x0=x+x.spacing/2).dy(x0=y) + + Derivative(r02 * field, x, deriv_order=0, fd_order=fd_order, + x0=x+x.spacing/2).dz(x0=z)) + + def foo1(field): + return (Derivative(r10 * field, y, deriv_order=0, fd_order=fd_order, + x0=y+y.spacing/2).dx(x0=x) + + (r11 * field).dy(x0=y+y.spacing/2) + + Derivative(r12 * field, y, deriv_order=0, fd_order=fd_order, + x0=y+y.spacing/2).dz(x0=z)) + + def foo2(field): + return (Derivative(r20 * field, z, deriv_order=0, fd_order=fd_order, + x0=z+z.spacing/2).dx(x0=x) + + Derivative(r21 * field, z, deriv_order=0, fd_order=fd_order, + x0=z+z.spacing/2).dy(x0=y) + + (r22 * field).dz(x0=z+z.spacing/2)) + + eqns = [Eq(txz.forward, txz + foo0(vz.forward) + foo2(vx.forward)), + Eq(txy.forward, txy + foo0(vy.forward) + foo1(vx.forward))] + + op = Operator(eqns, subs=grid.spacing_map, + opt=('advanced', {'openmp': True, + 'cire-rotate': True})) + + # Check code generation + bns, _ = assert_blocking(op, {'x0_blk0'}) + xs, ys, zs = get_params(op, 'x0_blk0_size', 'y0_blk0_size', 'z_size') + arrays = [i for i in FindSymbols().visit(bns['x0_blk0']) if i.is_Array] + assert len(arrays) == 10 + assert len([i for i in arrays if i.shape == (zs,)]) == 2 + assert len([i for i in arrays if i.shape == (9, zs)]) == 2 + + assert op._profiler._sections['section1'].sops == 184 + @pytest.mark.parametrize('rotate', [False, True]) @switchconfig(profiling='advanced') def test_nested_first_derivatives(self, rotate): @@ -2153,7 +2217,7 @@ def test_multiple_rotating_dims(self): # to jit-compile `op1`. However, we also check numerical correctness op1.apply(time_m=0, time_M=nt-2, dt=dt, u=u1, vx=vx1, vy=vy1) - assert np.allclose(u.data, u1.data, rtol=1e-5) + assert np.allclose(u.data, u1.data, rtol=1e-3) def test_maxpar_option_v2(self): """