Skip to content

Commit

Permalink
Merge pull request #2311 from devitocodes/patch-delta-compute
Browse files Browse the repository at this point in the history
compiler: Optimize normalize_reductions_dense
  • Loading branch information
mloubout authored Feb 14, 2024
2 parents 8129669 + 75d8d81 commit 1428bbc
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,11 +475,15 @@ def pull_indexeds(expr, subs, mapper, parent=None):
return cluster.rebuild(processed)


@cluster_pass(mode='dense')
def normalize_reductions_dense(cluster, sregistry, options):
    """
    Extract the right-hand sides of reduction Eq's into temporaries.

    Thin entry point: forwards `cluster`, `sregistry` and `options` unchanged
    to `_normalize_reductions_dense`, together with a fresh, empty dict that
    the callee receives as its `mapper` argument.
    """
    # A new dict is built on every call, so this function never hands the
    # same mapper object to two separate invocations.
    return _normalize_reductions_dense(cluster, sregistry, options, {})


@cluster_pass(mode='dense')
def _normalize_reductions_dense(cluster, sregistry, options, mapper):
opt_mapify_reduce = options['mapify-reduce']

dims = [d for d in cluster.ispace.itdims
Expand All @@ -499,20 +503,25 @@ def normalize_reductions_dense(cluster, sregistry, options):
# `s += r[x]`
# This makes it much easier to parallelize the map part regardless
# of the target backend
lhs, rhs = e.args

if e.lhs.function.is_Array:
if lhs.function.is_Array:
# Probably a compiler-generated reduction, e.g. via
# recursive compilation; it's an Array already, so nothing to do
processed.append(e)
elif rhs in mapper:
# Seen this RHS already, so reuse the Array that was created for it
processed.append(e.func(lhs, mapper[rhs].indexify()))
else:
# Here the LHS could be a Symbol or a user-level Function
# In the latter case we copy the data into a temporary Array
# because the Function might be padded, and reduction operations
# require, in general, the data values to be contiguous
name = sregistry.make_name()
a = Array(name=name, dtype=e.dtype, dimensions=dims)
processed.extend([Eq(a.indexify(), e.rhs),
e.func(e.lhs, a.indexify())])
a = mapper[rhs] = Array(name=name, dtype=e.dtype, dimensions=dims)

processed.extend([Eq(a.indexify(), rhs),
e.func(lhs, a.indexify())])
else:
processed.append(e)

Expand Down

0 comments on commit 1428bbc

Please sign in to comment.