Skip to content

Commit

Permalink
compiler: Patch unexpansion with double precision
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini authored and mloubout committed Oct 2, 2023
1 parent 5d98ac0 commit 7af8fa3
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 10 deletions.
3 changes: 0 additions & 3 deletions devito/ir/support/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,9 +209,6 @@ def is_parallel(self, dims):
return any(len(self[d] & {PARALLEL, PARALLEL_INDEP}) > 0
for d in as_tuple(dims))

def is_parallel_atomic(self, dims):
return any(len(self[d] & {PARALLEL_IF_ATOMIC}) > 0 for d in as_tuple(dims))

def is_parallel_relaxed(self, dims):
return any(len(self[d] & PARALLELS) > 0 for d in as_tuple(dims))

Expand Down
2 changes: 1 addition & 1 deletion devito/passes/clusters/derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def _core(expr, c, weights, mapper, sregistry):
try:
w = weights[k]
except KeyError:
w = weights[k] = w0._rebuild(name=name)
w = weights[k] = w0._rebuild(name=name, dtype=expr.dtype)
expr = uxreplace(expr, {w0.indexed: w.indexed})

dims = retrieve_dimensions(expr, deep=True)
Expand Down
7 changes: 1 addition & 6 deletions devito/passes/clusters/implicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,17 +146,12 @@ def callback(self, clusters, prefix):
)
tip = nxt

<<<<<<< HEAD
if ispaceN:
ispace = IterationSpace.union(c.ispace, ispaceN)
ispace = IterationSpace.union(c.ispace, ispaceN, relations=relations)
processed.append(c.rebuild(ispace=ispace))
else:
processed.append(c)
seen.add(d)
=======
ispace = IterationSpace.union(c.ispace, ispaceN, relations=relations)
processed.append(c.rebuild(ispace=ispace))
>>>>>>> 68b8b814a (compiler: Speedup codegen by minimizing relations)

return processed

Expand Down
14 changes: 14 additions & 0 deletions tests/test_unexpansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,20 @@ def test_v1(self):
op.cfunction


class TestMisc(object):

def test_double_precision(self):
grid = Grid(shape=(10, 10, 10), dtype=np.float64)

u = TimeFunction(name='u', grid=grid, space_order=4)

eqns = Eq(u.forward, u.laplace + 1.)

op = Operator(eqns, opt=('advanced', {'expand': False}))

op.cfunction


def tti_sa_eqns(grid):
t = grid.stepping_dim
x, y, z = grid.dimensions
Expand Down

0 comments on commit 7af8fa3

Please sign in to comment.