Skip to content

Commit

Permalink
compiler: Patch mapify-reduce for SparseTimeFunctions
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini authored and mloubout committed Sep 9, 2024
1 parent 65b69e5 commit 4526d35
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 7 deletions.
20 changes: 15 additions & 5 deletions devito/ir/clusters/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from devito.finite_differences.elementary import Max, Min
from devito.ir.support import (Any, Backward, Forward, IterationSpace, erange,
pull_dims, null_ispace)
from devito.ir.equations import OpMin, OpMax
from devito.ir.equations import OpMin, OpMax, identity_mapper
from devito.ir.clusters.analysis import analyze
from devito.ir.clusters.cluster import Cluster, ClusterGroup
from devito.ir.clusters.visitors import Queue, QueueStateful, cluster_pass
Expand Down Expand Up @@ -580,8 +580,8 @@ def normalize_reductions_minmax(cluster):
elif e.operation is OpMax:
if not f.is_Input:
expr = Eq(lhs, limits_mapper[lhs.dtype].min)
ispce = cluster.ispace.project(lambda i: i not in dims)
init.append(cluster.rebuild(exprs=expr, ispace=ispce))
ispace = cluster.ispace.project(lambda i: i not in dims)
init.append(cluster.rebuild(exprs=expr, ispace=ispace))

processed.append(e.func(lhs, Max(lhs, rhs)))

Expand Down Expand Up @@ -663,8 +663,18 @@ def _normalize_reductions_dense(cluster, mapper, sregistry, platform):
a = mapper[rhs] = Array(name=name, dtype=e.dtype, dimensions=dims,
grid=grid)

processed.extend([Eq(a.indexify(), rhs),
e.func(lhs, a.indexify())])
# Populate the Array (the "map" part)
processed.append(e.func(a.indexify(), rhs, operation=None))

# Set all untouched entried to the identity value if necessary
if e.conditionals:
nc = {d: sympy.Not(v) for d, v in e.conditionals.items()}
v = identity_mapper[e.lhs.dtype][e.operation]
processed.append(
e.func(a.indexify(), v, operation=None, conditionals=nc)
)

processed.append(e.func(lhs, a.indexify()))

for d in sequentialize:
properties = properties.sequentialize(d)
Expand Down
22 changes: 20 additions & 2 deletions devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from functools import cached_property

import numpy as np
import sympy

from devito.ir.equations.algorithms import dimension_sort, lower_exprs
from devito.finite_differences.differentiable import diff2sympy
from devito.ir.support import (GuardFactor, Interval, IntervalGroup, IterationSpace,
Stencil, detect_io, detect_accesses)
from devito.symbolics import IntDiv, uxreplace
from devito.symbolics import IntDiv, limits_mapper, uxreplace
from devito.tools import Pickable, Tag, frozendict
from devito.types import Eq, Inc, ReduceMax, ReduceMin, relational_min

__all__ = ['LoweredEq', 'ClusterizedEq', 'DummyEq', 'OpInc', 'OpMin', 'OpMax']
__all__ = ['LoweredEq', 'ClusterizedEq', 'DummyEq', 'OpInc', 'OpMin', 'OpMax',
'identity_mapper']


class IREq(sympy.Eq, Pickable):
Expand Down Expand Up @@ -119,6 +121,22 @@ def detect(cls, expr):
OpMin = Operation('min')


identity_mapper = {
np.int32: {OpInc: sympy.S.Zero,
OpMax: limits_mapper[np.int32].min,
OpMin: limits_mapper[np.int32].max},
np.int64: {OpInc: sympy.S.Zero,
OpMax: limits_mapper[np.int64].min,
OpMin: limits_mapper[np.int64].max},
np.float32: {OpInc: sympy.S.Zero,
OpMax: limits_mapper[np.float32].min,
OpMin: limits_mapper[np.float32].max},
np.float64: {OpInc: sympy.S.Zero,
OpMax: limits_mapper[np.float64].min,
OpMin: limits_mapper[np.float64].max},
}


class LoweredEq(IREq):

"""
Expand Down

0 comments on commit 4526d35

Please sign in to comment.