Skip to content

Commit

Permalink
compiler: Patch cire-rotate
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Oct 22, 2024
1 parent 154140f commit 6c39cfb
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 4 deletions.
4 changes: 3 additions & 1 deletion devito/passes/clusters/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,9 @@ def optimize_schedule_rotations(schedule, sregistry):
iis = candidate.lower
iib = candidate.upper

ii = ModuloDimension('%sii' % d.root.name, ds, iis, incr=iib)
name = sregistry.make_name(prefix='%sii' % d.root.name)
ii = ModuloDimension(name, ds, iis, incr=iib)

cd = CustomDimension(name='%sc' % d.root.name, symbolic_min=ii,
symbolic_max=iib, symbolic_size=n)
dsi = ModuloDimension('%si' % ds.root.name, cd, cd + ds - iis, n)
Expand Down
11 changes: 9 additions & 2 deletions examples/performance/00_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1358,10 +1358,17 @@
" {\n",
" for (int x = x0_blk0; x <= MIN(x_M, x0_blk0 + x0_blk0_size - 1); x += 1)\n",
" {\n",
" for (int y = y0_blk0, ys = 0, yr0 = (ys)%(5), yr1 = (ys + 3)%(5), yr2 = (ys + 4)%(5), yr3 = (ys + 1)%(5), yii = -2; y <= MIN(y_M, y0_blk0 + y0_blk0_size - 1); y += 1, ys += 1, yr0 = (ys)%(5), yr1 = (ys + 3)%(5), yr2 = (ys + 4)%(5), yr3 = (ys + 1)%(5), yii = 2)\n",
" for (int y = y0_blk0, ys = 0, yii0 = -2; y <= MIN(y_M, y0_blk0 + y0_blk0_size - 1); y += 1, ys += 1, yii0 = 2)\n",
" {\n",
" for (int yc = yii, yi = (yc + ys + 2)%(5); yc <= 2; yc += 1, yi = (yc + ys + 2)%(5))\n",
" int yr0 = (ys)%(5);\n",
" int yr1 = (ys + 3)%(5);\n",
" int yr2 = (ys + 4)%(5);\n",
" int yr3 = (ys + 1)%(5);\n",
"\n",
" for (int yc = yii0; yc <= 2; yc += 1)\n",
" {\n",
" int yi = (yc + ys + 2)%(5);\n",
"\n",
" #pragma omp simd aligned(u:32)\n",
" for (int z = z_m; z <= z_M; z += 1)\n",
" {\n",
Expand Down
52 changes: 51 additions & 1 deletion tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
ConditionalDimension, DefaultDimension, Grid, Operator,
norm, grad, div, dimensions, switchconfig, configuration,
centered, first_derivative, solve, transpose, Abs, cos,
sin, sqrt, floor, Ge, Lt)
sin, sqrt, floor, Ge, Lt, Derivative)
from devito.exceptions import InvalidArgument, InvalidOperator
from devito.ir import (Conditional, DummyEq, Expression, Iteration, FindNodes,
FindSymbols, ParallelIteration, retrieve_iteration_tree)
Expand Down Expand Up @@ -2102,6 +2102,56 @@ def test_maxpar_option(self, rotate):
op1.apply(time_M=2, u=u1)
assert np.isclose(norm(u), norm(u1), rtol=1e-5)

def test_multiple_rotating_dims(self):
space_order = 8
grid = Grid(shape=(51, 51, 51))
x, y, z = grid.dimensions

dt = 0.1
nt = 5

u = TimeFunction(name="u", grid=grid, space_order=space_order)
vx = TimeFunction(name="vx", grid=grid, space_order=space_order)
vy = TimeFunction(name="vy", grid=grid, space_order=space_order)

f = Function(name='f', grid=grid, space_order=space_order)
g = Function(name='g', grid=grid, space_order=space_order)

expr0 = 1-cos(f)**2
expr1 = sin(f)*cos(f)
expr2 = sin(g)*cos(f)
expr3 = (1-cos(g))*sin(f)*cos(f)

stencil0 = ((expr0*vx.forward).dx(x0=x-x.spacing/2) +
Derivative(expr1*vx.forward, x, deriv_order=0, fd_order=2,
x0=x-x.spacing/2).dy(x0=y) +
Derivative(expr2*vx.forward, x, deriv_order=0, fd_order=2,
x0=x-x.spacing/2).dz(x0=z))
stencil1 = Derivative(expr3*vy.forward, y, deriv_order=0, fd_order=2,
x0=y-y.spacing/2).dx(x0=x)

eqns = [Eq(vx.forward, u*.1),
Eq(vy.forward, u*.1),
Eq(u.forward, stencil0 + stencil1 + .1)]

op0 = Operator(eqns)
op1 = Operator(eqns, opt=("advanced", {"cire-rotate": True}))

f.data_with_halo[:] = .3
g.data_with_halo[:] = .7

u1 = u.func(name='u1')
vx1 = vx.func(name='vx1')
vy1 = vy.func(name='vy1')

op0.apply(time_m=0, time_M=nt-2, dt=dt)

# NOTE: the main issue leading to this test was actually failing
# to jit-compile `op1`. However, we also check numerical correctness
op1.apply(time_m=0, time_M=nt-2, dt=dt, u=u1, vx=vx1, vy=vy1)

assert np.allclose(u.data, u1.data, rtol=1e-5)

def test_maxpar_option_v2(self):
"""
Another test for the compiler option `cire-maxpar=True`.
Expand Down

0 comments on commit 6c39cfb

Please sign in to comment.