Skip to content

Commit

Permalink
compiler: Improve CIRE's cost model
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Oct 31, 2024
1 parent f95007d commit dea7217
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 7 deletions.
55 changes: 49 additions & 6 deletions devito/passes/clusters/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,15 +958,58 @@ def pick_best(variants):

flops = flops0 + flops1

# Data movement in the two sweeps
indexeds0 = search([sa.pivot for sa in i.schedule], Indexed)
# Estimate the data movement in the two sweeps

# With cross-loop blocking, a Function appearing in both sweeps is
# much more likely to be in cache during the second sweep, hence
# we count it only once
functions0 = set()
functions1 = set()
for sa in i.schedule:
indexeds0 = search(sa.pivot, Indexed)

if any(d.is_Block for d in sa.ispace.itdims):
functions1.update({i.function for i in indexeds0})
else:
functions0.update({i.function for i in indexeds0})

indexeds1 = search(i.exprs, Indexed)
functions1.update({i.function for i in indexeds1})

nfunctions0 = len(functions0)
nfunctions1 = len(functions1)

# All temporaries impact data movement, but some kind of temporaries
# are more likely to be in cache than others, so they are given a
# lighter weight
for ii in indexeds1:
grid = ii.function.grid
if grid is None:
continue

ntemps = len(i.schedule)
nfunctions0 = len({i.function for i in indexeds0})
nfunctions1 = len({i.function for i in indexeds1})
ntemps = 0
for sa in i.schedule:
if len(sa.writeto) < grid.dim:
# Tiny temporary, extremely likely to be in cache, hardly
# impacting data movement in a significant way
ntemps += 0.1
elif any(d.is_Block for d in sa.writeto.itdims):
# Cross-loop blocking temporary, likely to be in some level
# of cache (but unlikely to be in the fastest level)
ntemps += 1
else:
# Grid-size temporary, likely _not_ to be in cache, and
# therefore requiring at least two costly accesses per
# grid point
ntemps += 2

ntemps = int(ntemps)

break
else:
ntemps = len(i.schedule)

ws = ntemps*2 + nfunctions0 + nfunctions1
ws = ntemps + nfunctions0 + nfunctions1

if best is None:
best, best_flops, best_ws = i, flops, ws
Expand Down
66 changes: 65 additions & 1 deletion tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1915,6 +1915,70 @@ def test_tti_adjoint_akin_v2(self):
# all redundancies have been detected correctly
assert summary1[('section1', None)].ops == 75

@switchconfig(profiling='advanced')
def test_tti_adjoint_akin_v3(self):
so = 8
fd_order = 2

grid = Grid(shape=(20, 20, 20))
x, y, z = grid.dimensions

vx = TimeFunction(name="vx", grid=grid, space_order=so)
vy = TimeFunction(name="vy", grid=grid, space_order=so)
vz = TimeFunction(name="vz", grid=grid, space_order=so)
txy = TimeFunction(name="txy", grid=grid, space_order=so)
txz = TimeFunction(name="txz", grid=grid, space_order=so)
theta = Function(name='theta', grid=grid, space_order=so)
phi = Function(name='phi', grid=grid, space_order=so)

r00 = 1 - (1 - cos(theta))*cos(phi)**2
r01 = -(1 - cos(theta))*sin(phi)*cos(phi)
r02 = -sin(theta)*cos(phi)
r10 = -(1 - cos(theta))*sin(phi)*cos(phi)
r11 = 1 - (1 - cos(theta))*sin(phi)**2
r12 = -sin(theta)*sin(phi)
r20 = sin(theta)*cos(phi)
r21 = sin(theta)*sin(phi)
r22 = cos(theta)

def foo0(field):
return ((r00 * field).dx(x0=x+x.spacing/2) +
Derivative(r01 * field, x, deriv_order=0, fd_order=fd_order,
x0=x+x.spacing/2).dy(x0=y) +
Derivative(r02 * field, x, deriv_order=0, fd_order=fd_order,
x0=x+x.spacing/2).dz(x0=z))

def foo1(field):
return (Derivative(r10 * field, y, deriv_order=0, fd_order=fd_order,
x0=y+y.spacing/2).dx(x0=x) +
(r11 * field).dy(x0=y+y.spacing/2) +
Derivative(r12 * field, y, deriv_order=0, fd_order=fd_order,
x0=y+y.spacing/2).dz(x0=z))

def foo2(field):
return (Derivative(r20 * field, z, deriv_order=0, fd_order=fd_order,
x0=z+z.spacing/2).dx(x0=x) +
Derivative(r21 * field, z, deriv_order=0, fd_order=fd_order,
x0=z+z.spacing/2).dy(x0=y) +
(r22 * field).dz(x0=z+z.spacing/2))

eqns = [Eq(txz.forward, txz + foo0(vz.forward) + foo2(vx.forward)),
Eq(txy.forward, txy + foo0(vy.forward) + foo1(vx.forward))]

op = Operator(eqns, subs=grid.spacing_map,
opt=('advanced', {'openmp': True,
'cire-rotate': True}))

# Check code generation
bns, _ = assert_blocking(op, {'x0_blk0'})
xs, ys, zs = get_params(op, 'x0_blk0_size', 'y0_blk0_size', 'z_size')
arrays = [i for i in FindSymbols().visit(bns['x0_blk0']) if i.is_Array]
assert len(arrays) == 13
assert len([i for i in arrays if i.shape == (zs,)]) == 2
assert len([i for i in arrays if i.shape == (9, zs)]) == 1

assert op._profiler._sections['section1'].sops == 179

@pytest.mark.parametrize('rotate', [False, True])
@switchconfig(profiling='advanced')
def test_nested_first_derivatives(self, rotate):
Expand Down Expand Up @@ -2153,7 +2217,7 @@ def test_multiple_rotating_dims(self):
# to jit-compile `op1`. However, we also check numerical correctness
op1.apply(time_m=0, time_M=nt-2, dt=dt, u=u1, vx=vx1, vy=vy1)

assert np.allclose(u.data, u1.data, rtol=1e-5)
assert np.allclose(u.data, u1.data, rtol=1e-3)

def test_maxpar_option_v2(self):
"""
Expand Down

0 comments on commit dea7217

Please sign in to comment.