Skip to content

Commit

Permalink
compiler: fix corner case reduction
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Sep 29, 2023
1 parent 7887f86 commit bd54b78
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 3 deletions.
4 changes: 2 additions & 2 deletions devito/ir/clusters/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,9 @@ def __init__(self, func, mode='dense'):
self.func = func

if mode == 'dense':
self.cond = lambda c: c.is_dense
self.cond = lambda c: c.is_dense or not c.is_sparse
elif mode == 'sparse':
self.cond = lambda c: not c.is_dense
self.cond = lambda c: c.is_sparse
else:
self.cond = lambda c: True

Expand Down
28 changes: 27 additions & 1 deletion tests/test_dle.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
retrieve_iteration_tree, Expression)
from devito.passes.iet.languages.openmp import Ompizer, OmpRegion
from devito.tools import as_tuple
from devito.types import Scalar
from devito.types import Scalar, Symbol


def get_blocksizes(op, opt, grid, blockshape, level=0):
Expand Down Expand Up @@ -789,6 +789,32 @@ def test_reduction_local(self):

assert n.data[0] == 11*11

def test_mapify_reduction_sparse(self):
grid = Grid((11, 11))
s = SparseTimeFunction(name="s", grid=grid, npoint=1, nt=11)
s.data.fill(1.)
r = Symbol(name="r", dtype=np.float32)
n0 = Function(name="n0", dimensions=(Dimension("noi"),), shape=(1,))

eqns = [Eq(r, 0), Inc(r, s*s), Eq(n0[0], r)]
op0 = Operator(eqns)
op1 = Operator(eqns, opt=('advanced', {'mapify-reduce': True}))

expr0 = FindNodes(Expression).visit(op0)
assert len(expr0) == 3
assert expr0[1].is_reduction

expr1 = FindNodes(Expression).visit(op1)
assert len(expr1) == 4
assert expr1[1].expr.lhs.indices == s.indices
assert expr1[2].expr.rhs.is_Indexed
assert expr1[2].is_reduction

op0()
assert n0.data[0] == 11
op1()
assert n0.data[0] == 11

def test_array_max_reduction(self):
"""
Test generation of OpenMP sum-reduction clauses involving Function's.
Expand Down

0 comments on commit bd54b78

Please sign in to comment.