Skip to content

Commit

Permalink
Merge pull request #2315 from devitocodes/patch-delta-compr-2
Browse files Browse the repository at this point in the history
compiler: Fix min/max reductions to be backend-portable
  • Loading branch information
mloubout authored Feb 14, 2024
2 parents 1428bbc + 3e504f9 commit 747c0fe
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 30 deletions.
66 changes: 54 additions & 12 deletions devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@
import sympy

from devito.exceptions import InvalidOperator
from devito.finite_differences.elementary import Max, Min
from devito.ir.support import (Any, Backward, Forward, IterationSpace, erange,
pull_dims)
from devito.ir.equations import OpMin, OpMax
from devito.ir.clusters.analysis import analyze
from devito.ir.clusters.cluster import Cluster, ClusterGroup
from devito.ir.clusters.visitors import Queue, QueueStateful, cluster_pass
from devito.mpi.halo_scheme import HaloScheme, HaloTouch
from devito.symbolics import retrieve_indexed, uxreplace, xreplace_indices
from devito.symbolics import (limits_mapper, retrieve_indexed, uxreplace,
xreplace_indices)
from devito.tools import (DefaultOrderedDict, Stamp, as_mapper, flatten,
is_integer, timed_pass, toposort)
from devito.types import Array, Eq, Symbol
Expand Down Expand Up @@ -41,7 +44,7 @@ def clusterize(exprs, **kwargs):
# Determine relevant computational properties (e.g., parallelism)
clusters = analyze(clusters)

# Input normalization (e.g., SSA)
# Input normalization
clusters = normalize(clusters, **kwargs)

# Derive the necessary communications for distributed-memory parallelism
Expand Down Expand Up @@ -432,8 +435,11 @@ def normalize(clusters, **kwargs):
sregistry = kwargs['sregistry']

clusters = normalize_nested_indexeds(clusters, sregistry)
clusters = normalize_reductions_dense(clusters, sregistry, options)
clusters = normalize_reductions_sparse(clusters, sregistry, options)
if options['mapify-reduce']:
clusters = normalize_reductions_dense(clusters, sregistry)
else:
clusters = normalize_reductions_minmax(clusters)
clusters = normalize_reductions_sparse(clusters, sregistry)

return clusters

Expand Down Expand Up @@ -475,26 +481,62 @@ def pull_indexeds(expr, subs, mapper, parent=None):
return cluster.rebuild(processed)


def normalize_reductions_dense(cluster, sregistry, options):
@cluster_pass(mode='dense')
def normalize_reductions_minmax(cluster):
    """
    Make min/max reductions explicit in the expressions of `cluster`.

    An expression whose operation is `OpMin` (`OpMax`) is rebuilt so that its
    right-hand side becomes `Min(lhs, rhs)` (`Max(lhs, rhs)`). Unless the
    reduction variable is an input, an additional Cluster is emitted ahead of
    the reduction to initialize the variable to the neutral element of the
    operation, that is the `max` (`min`) entry of `limits_mapper` for its
    dtype. All other expressions are left untouched.

    If `cluster` has no parallel-atomic Dimensions, it is returned as is.
    """
    atomic_dims = [d for d in cluster.ispace.itdims
                   if cluster.properties.is_parallel_atomic(d)]
    if not atomic_dims:
        return cluster

    def initialize(target, value):
        # The initialization must not be carried out within the reduction
        # Dimensions, hence these are projected away from the IterationSpace
        expr = Eq(target, value)
        ispace = cluster.ispace.project(lambda i: i not in atomic_dims)
        return cluster.rebuild(exprs=expr, ispace=ispace)

    prologue = []
    reduced = []
    for e in cluster.exprs:
        lhs, rhs = e.args
        function = lhs.function

        if e.operation is OpMin:
            # The neutral element of a min-reduction is the largest value
            combine, neutral = Min, 'max'
        elif e.operation is OpMax:
            # The neutral element of a max-reduction is the smallest value
            combine, neutral = Max, 'min'
        else:
            reduced.append(e)
            continue

        if not function.is_Input:
            limit = getattr(limits_mapper[lhs.dtype], neutral)
            prologue.append(initialize(lhs, limit))

        reduced.append(e.func(lhs, combine(lhs, rhs)))

    return prologue + [cluster.rebuild(reduced)]


def normalize_reductions_dense(cluster, sregistry):
    """
    Extract the right-hand sides of reduction Eq's into temporaries.

    This is a thin entry point: the actual work is delegated to
    `_normalize_reductions_dense`, which is handed a fresh, empty `mapper`
    on each call.
    """
    # NOTE: the helper takes `(cluster, sregistry, mapper)`; there is no
    # `options` argument, since the `mapify-reduce` switch is checked by the
    # caller
    return _normalize_reductions_dense(cluster, sregistry, {})


@cluster_pass(mode='dense')
def _normalize_reductions_dense(cluster, sregistry, options, mapper):
opt_mapify_reduce = options['mapify-reduce']

def _normalize_reductions_dense(cluster, sregistry, mapper):
dims = [d for d in cluster.ispace.itdims
if cluster.properties.is_parallel_atomic(d)]

if not dims:
return cluster

processed = []
for e in cluster.exprs:
if e.is_Reduction and opt_mapify_reduce:
if e.is_Reduction:
# Transform `e` into what is in essence an explicit map-reduce
# For example, turn:
# `s += f(u[x], v[x], ...)`
Expand Down Expand Up @@ -529,7 +571,7 @@ def _normalize_reductions_dense(cluster, sregistry, options, mapper):


@cluster_pass(mode='sparse')
def normalize_reductions_sparse(cluster, sregistry, options):
def normalize_reductions_sparse(cluster, sregistry):
"""
Extract the right-hand sides of reduction Eq's into temporaries.
"""
Expand Down
2 changes: 1 addition & 1 deletion devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def __repr__(self):
elif self.operation is OpInc:
return '%s += %s' % (self.lhs, self.rhs)
else:
return '%s = %s(%s, %s)' % (self.lhs, self.operation, self.lhs, self.rhs)
return '%s = %s(%s)' % (self.lhs, self.operation, self.rhs)

# Pickling support
__reduce_ex__ = Pickable.__reduce_ex__
Expand Down
3 changes: 1 addition & 2 deletions devito/passes/clusters/factorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,7 @@ def _collect_nested(expr):
return expr, {'funcs': expr}
elif expr.is_Pow:
return expr, {'pows': expr}
elif (expr.is_Symbol or
expr.is_Indexed or
elif (expr.is_Symbol or expr.is_Indexed or not expr.args or
isinstance(expr, (BasicWrapperMixin, AbstractObject))):
return expr, {}
elif expr.is_Add:
Expand Down
15 changes: 13 additions & 2 deletions devito/passes/iet/misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import singledispatch

import cgen
import numpy as np
import sympy

from devito.finite_differences import Max, Min
Expand All @@ -9,7 +10,7 @@
filter_iterations, retrieve_iteration_tree, pull_dims)
from devito.passes.iet.engine import iet_pass
from devito.ir.iet.efunc import DeviceFunction, EntryFunction
from devito.symbolics import evalrel, has_integer_args
from devito.symbolics import ValueLimit, evalrel, has_integer_args, limits_mapper
from devito.tools import as_mapper, filter_ordered, split

__all__ = ['avoid_denormals', 'hoist_prodders', 'relax_incr_dimensions',
Expand Down Expand Up @@ -136,10 +137,20 @@ def relax_incr_dimensions(iet, options=None, **kwargs):

@iet_pass
def generate_macros(iet):
    """
    Collect the headers and the includes required by the objects in `iet`.

    Returns
    -------
    tuple
        The unmodified `iet` and a dict with two entries: 'headers', the
        union of the macros generated by `_generate_macros` for each
        application found in `iet`; and 'includes', the standard library
        headers defining the value-limit macros (`ValueLimit`) used in `iet`.
    """
    # Generate Macros from higher-level SymPy objects
    applications = FindApplications().visit(iet)
    headers = set().union(*[_generate_macros(i) for i in applications])

    # Some special Symbols may represent Macros defined in standard libraries,
    # so we need to include the respective includes
    limits = FindApplications(ValueLimit).visit(iet)
    int_limits = set(limits_mapper[np.int32]) | set(limits_mapper[np.int64])
    float_limits = set(limits_mapper[np.float32]) | set(limits_mapper[np.float64])

    # The two checks are independent: an `iet` may use both integer and
    # floating-point limits, in which case both includes are necessary
    includes = set()
    if limits & int_limits:
        includes.add('limits.h')
    if limits & float_limits:
        includes.add('float.h')

    return iet, {'headers': headers, 'includes': includes}


@singledispatch
Expand Down
13 changes: 9 additions & 4 deletions devito/passes/iet/parpragma.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections import defaultdict

import numpy as np
import cgen as c
from cached_property import cached_property
Expand Down Expand Up @@ -173,7 +175,7 @@ def _make_clause_reduction_from_imask(cls, reductions):
Build a string representation of the reduction clause(s) given a list of
3-tuples `(symbol, imask, ir.Operation)`.
"""
args = []
mapper = defaultdict(list)
for i, imask, r in reductions:
if i.is_Indexed:
f = i.function
Expand All @@ -188,10 +190,13 @@ def _make_clause_reduction_from_imask(cls, reductions):
else:
assert isinstance(k, tuple) and len(k) == 2
bounds.append('[%s:%s]' % k)
args.append('%s%s' % (i.name, ''.join(bounds)))
mapper[r.name].append('%s%s' % (i.name, ''.join(bounds)))
else:
args.append(str(i))
return 'reduction(%s:%s)' % (r.name, ','.join(args))
mapper[r.name].append(str(i))

args = ['reduction(%s:%s)' % (k, ','.join(v)) for k, v in mapper.items()]

return ' '.join(args)

@cached_property
def collapsed(self):
Expand Down
32 changes: 27 additions & 5 deletions devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,20 @@
from sympy import Expr, Function, Number, Tuple, sympify
from sympy.core.decorators import call_highest_priority

from devito.tools import (Pickable, as_tuple, is_integer, float2, float3, float4, # noqa
double2, double3, double4, int2, int3, int4)
from devito.finite_differences.elementary import Min, Max
from devito.tools import (Pickable, Bunch, as_tuple, is_integer, float2, # noqa
float3, float4, double2, double3, double4, int2, int3,
int4)
from devito.types import Symbol
from devito.types.basic import Basic

__all__ = ['CondEq', 'CondNe', 'IntDiv', 'CallFromPointer', # noqa
'CallFromComposite', 'FieldFromPointer', 'FieldFromComposite',
'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction',
'MathFunction', 'InlineIf', 'ReservedWord', 'Keyword', 'String',
'Macro', 'Class', 'MacroArgument', 'CustomType', 'Deref',
'Namespace', 'Rvalue', 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'Null',
'SizeOf', 'rfunc', 'cast_mapper', 'BasicWrapperMixin']
'Macro', 'Class', 'MacroArgument', 'CustomType', 'Deref', 'Namespace',
'Rvalue', 'INT', 'FLOAT', 'DOUBLE', 'VOID', 'Null', 'SizeOf', 'rfunc',
'cast_mapper', 'BasicWrapperMixin', 'ValueLimit', 'limits_mapper']


class CondEq(sympy.Eq):
Expand Down Expand Up @@ -497,6 +498,9 @@ def __str__(self):
def _hashable_content(self):
return (self.value,)

def _sympystr(self, printer):
    """
    Hook for SymPy's string printer: render the object as `str(self)`.
    The `printer` argument is not used.
    """
    return str(self)

# Pickling support
__reduce_ex__ = Pickable.__reduce_ex__

Expand Down Expand Up @@ -533,6 +537,24 @@ def __str__(self):
__repr__ = __str__


class ValueLimit(ReservedWord, sympy.Expr):

    """
    Symbolic representation of the so-called limits macros, that is the
    macros providing the minimum and maximum values of the various types,
    such as INT_MIN, INT_MAX, etc.
    """


# Map a numpy dtype to the pair of symbolic limits macros bounding its values,
# accessible as `.min` and `.max`.
# For the floating-point types, the lower bound is expressed as `-<T>_MAX`;
# note that in C `FLT_MIN`/`DBL_MIN` are the smallest *positive* normalized
# values, so they cannot be used as the most negative value.
# NOTE(review): `np.int64` is mapped to `LONG_MIN`/`LONG_MAX`, which assumes
# that the C `long` is 64-bit wide on the target platform -- TODO confirm
limits_mapper = {
    np.int32: Bunch(min=ValueLimit('INT_MIN'), max=ValueLimit('INT_MAX')),
    np.int64: Bunch(min=ValueLimit('LONG_MIN'), max=ValueLimit('LONG_MAX')),
    np.float32: Bunch(min=-ValueLimit('FLT_MAX'), max=ValueLimit('FLT_MAX')),
    np.float64: Bunch(min=-ValueLimit('DBL_MAX'), max=ValueLimit('DBL_MAX')),
}


class DefFunction(Function, Pickable):

"""
Expand Down
2 changes: 1 addition & 1 deletion devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def _print_Fallback(self, expr):
_print_MacroArgument = _print_Fallback
_print_IndexedBase = _print_Fallback
_print_IndexSum = _print_Fallback
_print_Keyword = _print_Fallback
_print_ReservedWord = _print_Fallback
_print_Basic = _print_Fallback


Expand Down
40 changes: 37 additions & 3 deletions tests/test_dle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from conftest import assert_structure, assert_blocking, _R, skipif
from devito import (Grid, Function, TimeFunction, SparseTimeFunction, SpaceDimension,
CustomDimension, Dimension, SubDimension,
PrecomputedSparseTimeFunction, Eq, Inc, ReduceMax, Operator,
configuration, dimensions, info, cos)
PrecomputedSparseTimeFunction, Eq, Inc, ReduceMin, ReduceMax,
Operator, configuration, dimensions, info, cos)
from devito.exceptions import InvalidArgument
from devito.ir.iet import (Iteration, FindNodes, IsPerfectIteration,
retrieve_iteration_tree, Expression)
Expand Down Expand Up @@ -877,7 +877,7 @@ def test_mapify_reduction_sparse(self):

def test_array_max_reduction(self):
"""
Test generation of OpenMP sum-reduction clauses involving Function's.
Test generation of OpenMP max-reduction clauses involving Function's.
"""
grid = Grid(shape=(3, 3, 3))
i = Dimension(name='i')
Expand All @@ -902,6 +902,40 @@ def test_array_max_reduction(self):
with pytest.raises(NotImplementedError):
Operator(eqn, opt=('advanced', {'openmp': True}))

def test_array_minmax_reduction(self):
    """
    Test generation of OpenMP combined min- and max-reduction clauses
    involving Function's.
    """
    grid = Grid(shape=(3, 3, 3))
    i = Dimension(name='i')

    f = Function(name='f', grid=grid)
    # A two-entry Function to carry the reduction results out of the Operator
    n = Function(name='n', grid=grid, shape=(2,), dimensions=(i,))
    # The scalar reduction variables
    r0 = Symbol(name='r0', dtype=grid.dtype)
    r1 = Symbol(name='r1', dtype=grid.dtype)

    # `f` takes the values 0, 1, ..., 26, so the expected maximum is 26 and
    # the expected minimum is 0
    f.data[:] = np.arange(0, 27).reshape((3, 3, 3))

    eqns = [ReduceMax(r0, f),
            ReduceMin(r1, f),
            Eq(n[0], r0),
            Eq(n[1], r1)]

    # Nothing to test if the compiler in use lacks support for the reduction
    if not Ompizer._support_array_reduction(configuration['compiler']):
        return

    op = Operator(eqns)

    if configuration['language'] == 'openmp':
        # The max- and the min-reduction must appear as two distinct clauses
        # within the same pragma
        iterations = FindNodes(Iteration).visit(op)
        expected = "reduction(max:r0) reduction(min:r1)"
        assert expected in iterations[0].pragmas[0].value

    # Regardless of the language, the computed values must be correct
    op()
    assert n.data[0] == 26
    assert n.data[1] == 0

def test_incs_no_atomic(self):
"""
Test that `Inc`'s don't get a `#pragma omp atomic` if performing
Expand Down

0 comments on commit 747c0fe

Please sign in to comment.