-
Notifications
You must be signed in to change notification settings - Fork 230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
compiler: Improve CIRE's cost model #2476
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -958,15 +958,58 @@ def pick_best(variants): | |
|
||
flops = flops0 + flops1 | ||
|
||
# Data movement in the two sweeps | ||
indexeds0 = search([sa.pivot for sa in i.schedule], Indexed) | ||
# Estimate the data movement in the two sweeps | ||
|
||
# With cross-loop blocking, a Function appearing in both sweeps is | ||
# much more likely to be in cache during the second sweep, hence | ||
# we count it only once | ||
functions0 = set() | ||
functions1 = set() | ||
for sa in i.schedule: | ||
indexeds0 = search(sa.pivot, Indexed) | ||
|
||
if any(d.is_Block for d in sa.ispace.itdims): | ||
functions1.update({i.function for i in indexeds0}) | ||
else: | ||
functions0.update({i.function for i in indexeds0}) | ||
|
||
indexeds1 = search(i.exprs, Indexed) | ||
functions1.update({i.function for i in indexeds1}) | ||
|
||
nfunctions0 = len(functions0) | ||
nfunctions1 = len(functions1) | ||
|
||
# All temporaries impact data movement, but some kind of temporaries | ||
# are more likely to be in cache than others, so they are given a | ||
# lighter weight | ||
for ii in indexeds1: | ||
grid = ii.function.grid | ||
if grid is None: | ||
continue | ||
|
||
ntemps = len(i.schedule) | ||
nfunctions0 = len({i.function for i in indexeds0}) | ||
nfunctions1 = len({i.function for i in indexeds1}) | ||
ntemps = 0 | ||
for sa in i.schedule: | ||
if len(sa.writeto) < grid.dim: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's OK: `grid.dim` is a number (it should have been called `grid.ndim`, but the name is legacy), and `writeto` is an IterationSpace, so its length is the number of Dimensions in it. |
||
# Tiny temporary, extremely likely to be in cache, hardly | ||
# impacting data movement in a significant way | ||
ntemps += 0.1 | ||
elif any(d.is_Block for d in sa.writeto.itdims): | ||
# Cross-loop blocking temporary, likely to be in some level | ||
# of cache (but unlikely to be in the fastest level) | ||
ntemps += 1 | ||
else: | ||
# Grid-size temporary, likely _not_ to be in cache, and | ||
# therefore requiring at least two costly accesses per | ||
# grid point | ||
ntemps += 2 | ||
|
||
ntemps = int(ntemps) | ||
|
||
break | ||
else: | ||
ntemps = len(i.schedule) | ||
|
||
ws = ntemps*2 + nfunctions0 + nfunctions1 | ||
ws = ntemps + nfunctions0 + nfunctions1 | ||
|
||
if best is None: | ||
best, best_flops, best_ws = i, flops, ws | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1915,6 +1915,70 @@ def test_tti_adjoint_akin_v2(self): | |
# all redundancies have been detected correctly | ||
assert summary1[('section1', None)].ops == 75 | ||
|
||
@switchconfig(profiling='advanced')
def test_tti_adjoint_akin_v3(self):
    """
    Build two update equations (for `txz` and `txy`) out of first
    derivatives of `vx`, `vy`, `vz` weighted by trigonometric coefficients
    of `theta` and `phi`, compile them with the 'advanced' optimization
    mode (OpenMP and `cire-rotate` enabled), then check the Arrays found
    in the blocked loop nest and the operation count of `section1`.
    """
    so = 8
    fd_order = 2

    grid = Grid(shape=(20, 20, 20))
    x, y, z = grid.dimensions

    vx = TimeFunction(name="vx", grid=grid, space_order=so)
    vy = TimeFunction(name="vy", grid=grid, space_order=so)
    vz = TimeFunction(name="vz", grid=grid, space_order=so)
    txy = TimeFunction(name="txy", grid=grid, space_order=so)
    txz = TimeFunction(name="txz", grid=grid, space_order=so)
    theta = Function(name='theta', grid=grid, space_order=so)
    phi = Function(name='phi', grid=grid, space_order=so)

    # Trigonometric coefficients multiplying the fields before differentiation
    r00 = cos(theta)*cos(phi)
    r01 = cos(theta)*sin(phi)
    r02 = -sin(theta)
    r10 = -sin(phi)
    r11 = cos(phi)
    r12 = cos(theta)
    r20 = sin(theta)*cos(phi)
    r21 = sin(theta)*sin(phi)
    r22 = cos(theta)

    def half(dim):
        # Evaluation point shifted by half a grid spacing along `dim`
        return dim + dim.spacing/2

    def shifted(expr, dim):
        # Zero-order Derivative of `expr` along `dim`, evaluated at the
        # half-shifted point
        return Derivative(expr, dim, deriv_order=0, fd_order=fd_order,
                          x0=half(dim))

    def along_x(field):
        return ((r00 * field).dx(x0=half(x)) +
                shifted(r01 * field, x).dy(x0=y) +
                shifted(r02 * field, x).dz(x0=z))

    def along_y(field):
        return (shifted(r10 * field, y).dx(x0=x) +
                (r11 * field).dy(x0=half(y)) +
                shifted(r12 * field, y).dz(x0=z))

    def along_z(field):
        return (shifted(r20 * field, z).dx(x0=x) +
                shifted(r21 * field, z).dy(x0=y) +
                (r22 * field).dz(x0=half(z)))

    eq_txz = Eq(txz.forward, txz + along_x(vz.forward) + along_z(vx.forward))
    eq_txy = Eq(txy.forward, txy + along_x(vy.forward) + along_y(vx.forward))
    eqns = [eq_txz, eq_txy]

    options = {'openmp': True, 'cire-rotate': True}
    op = Operator(eqns, subs=grid.spacing_map, opt=('advanced', options))

    # Check code generation
    bns, _ = assert_blocking(op, {'x0_blk0'})
    xs, ys, zs = get_params(op, 'x0_blk0_size', 'y0_blk0_size', 'z_size')
    arrays = [a for a in FindSymbols().visit(bns['x0_blk0']) if a.is_Array]

    # 10 Arrays are expected in the blocked nest, two of shape `(zs,)` and
    # two of shape `(9, zs)`
    assert len(arrays) == 10
    assert sum(1 for a in arrays if a.shape == (zs,)) == 2
    assert sum(1 for a in arrays if a.shape == (9, zs)) == 2

    assert op._profiler._sections['section1'].sops == 184
|
||
@pytest.mark.parametrize('rotate', [False, True]) | ||
@switchconfig(profiling='advanced') | ||
def test_nested_first_derivatives(self, rotate): | ||
|
@@ -2153,7 +2217,7 @@ def test_multiple_rotating_dims(self): | |
# to jit-compile `op1`. However, we also check numerical correctness | ||
op1.apply(time_m=0, time_M=nt-2, dt=dt, u=u1, vx=vx1, vy=vy1) | ||
|
||
assert np.allclose(u.data, u1.data, rtol=1e-5) | ||
assert np.allclose(u.data, u1.data, rtol=1e-3) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why this much change? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Highly unstable equations and input data, and now more aggressive refactorings. |
||
|
||
def test_maxpar_option_v2(self): | ||
""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would expect to see some change in counted flops in other tests?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this indeed affecting only the new variant you added?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what we had before was good already, here I've been surgical...