Skip to content

Commit

Permalink
Merge pull request #1650 from devitocodes/revisit-uniqueness
Browse files Browse the repository at this point in the history
compiler: Singletonize special symbols (e.g. nthreads)
  • Loading branch information
FabioLuporini authored Apr 8, 2021
2 parents ea5c1f4 + 602ae81 commit 43788f3
Show file tree
Hide file tree
Showing 21 changed files with 207 additions and 118 deletions.
6 changes: 3 additions & 3 deletions devito/builtins/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def norm(f, order=2):
# otherwise we would eventually be summing more than expected
p, eqns = f.guard() if f.is_SparseFunction else (f, [])

s = dv.types.Scalar(name='sum', dtype=f.dtype)
s = dv.types.Symbol(name='sum', dtype=f.dtype)

with MPIReduction(f) as mr:
op = dv.Operator([dv.Eq(s, 0.0)] +
Expand Down Expand Up @@ -59,7 +59,7 @@ def sumall(f):
# otherwise we would eventually be summing more than expected
p, eqns = f.guard() if f.is_SparseFunction else (f, [])

s = dv.types.Scalar(name='sum', dtype=f.dtype)
s = dv.types.Symbol(name='sum', dtype=f.dtype)

with MPIReduction(f) as mr:
op = dv.Operator([dv.Eq(s, 0.0)] +
Expand Down Expand Up @@ -113,7 +113,7 @@ def inner(f, g):
# otherwise we would eventually be summing more than expected
rhs, eqns = f.guard(f*g) if f.is_SparseFunction else (f*g, [])

s = dv.types.Scalar(name='sum', dtype=f.dtype)
s = dv.types.Symbol(name='sum', dtype=f.dtype)

with MPIReduction(f, g) as mr:
op = dv.Operator([dv.Eq(s, 0.0)] +
Expand Down
2 changes: 1 addition & 1 deletion devito/core/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _autotune(self, args, setup):

@property
def nthreads(self):
nthreads = [i for i in self.input if type(i).__base__ is NThreads]
nthreads = [i for i in self.input if isinstance(i, NThreads)]
if len(nthreads) == 0:
return 1
else:
Expand Down
2 changes: 1 addition & 1 deletion devito/ir/iet/efunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ class ThreadFunction(Callable):
def _make_threads(value, sregistry):
name = sregistry.make_name(prefix='threads')

base_id = 1 + sum(i.data for i in sregistry.npthreads)
base_id = 1 + sum(i.size for i in sregistry.npthreads)

if value is None:
# The npthreads Symbol isn't actually used, but we record the fact
Expand Down
4 changes: 2 additions & 2 deletions devito/operations/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from devito.symbolics import retrieve_function_carriers, indexify, INT
from devito.tools import powerset, flatten, prod
from devito.types import (ConditionalDimension, Dimension, DefaultDimension, Eq, Inc,
Evaluable, Scalar, SubFunction)
Evaluable, Symbol, SubFunction)

__all__ = ['LinearInterpolator', 'PrecomputedInterpolator']

Expand Down Expand Up @@ -234,7 +234,7 @@ def callback():
for b, v_sub in zip(self._interpolation_coeffs, idx_subs)]

# Accumulate point-wise contributions into a temporary
rhs = Scalar(name='sum', dtype=self.sfunction.dtype)
rhs = Symbol(name='sum', dtype=self.sfunction.dtype)
summands = [Eq(rhs, 0., implicit_dims=self.sfunction.dimensions)]
summands.extend([Inc(rhs, i, implicit_dims=self.sfunction.dimensions)
for i in args])
Expand Down
10 changes: 5 additions & 5 deletions devito/operator/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ def __init__(self):
self.counters = {}

# Special symbols
self.nthreads = NThreads(aliases='nthreads0')
self.nthreads_nested = NThreadsNested(aliases='nthreads1')
self.nthreads_nonaffine = NThreadsNonaffine(aliases='nthreads2')
self.nthreads = NThreads()
self.nthreads_nested = NThreadsNested()
self.nthreads_nonaffine = NThreadsNonaffine()
self.threadid = ThreadID(self.nthreads)

# Several groups of pthreads each of size `npthread` may be created
Expand All @@ -36,8 +36,8 @@ def make_name(self, prefix=None):

return "%s%d" % (prefix, counter())

def make_npthreads(self, value):
def make_npthreads(self, size):
name = self.make_name(prefix='npthreads')
npthreads = NPThreads(name=name, value=value)
npthreads = NPThreads(name=name, size=size)
self.npthreads.append(npthreads)
return npthreads
6 changes: 3 additions & 3 deletions devito/passes/clusters/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from devito.symbolics import (Uxmapper, compare_ops, estimate_cost, q_constant,
q_leaf, retrieve_indexed, search, uxreplace)
from devito.tools import as_tuple, flatten, split
from devito.types import (Array, TempFunction, Eq, Scalar, ModuloDimension,
from devito.types import (Array, TempFunction, Eq, Symbol, ModuloDimension,
CustomDimension, IncrDimension)

__all__ = ['cire']
Expand Down Expand Up @@ -172,7 +172,7 @@ def make_schedule(self, cluster, context):
return SpacePoint(schedule, exprs, score)

def _make_symbol(self):
return Scalar(name=self.sregistry.make_name('dummy'))
return Symbol(name=self.sregistry.make_name('dummy'))

def _nrepeats(self, cluster):
raise NotImplementedError
Expand Down Expand Up @@ -801,7 +801,7 @@ def lower_schedule(cluster, schedule, sregistry, options):
# Degenerate case: scalar expression
assert writeto.size == 0

obj = Scalar(name=name, dtype=dtype)
obj = Symbol(name=name, dtype=dtype)
expression = Eq(obj, alias)

callback = lambda idx: obj
Expand Down
4 changes: 2 additions & 2 deletions devito/passes/clusters/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from devito.ir import DummyEq, Cluster, Scope
from devito.passes.clusters.utils import cluster_pass, makeit_ssa
from devito.symbolics import count, estimate_cost, q_xop, q_leaf, uxreplace
from devito.types import Scalar
from devito.types import Symbol

__all__ = ['cse']

Expand All @@ -13,7 +13,7 @@ def cse(cluster, sregistry, *args):
"""
Common sub-expressions elimination (CSE).
"""
make = lambda: Scalar(name=sregistry.make_name(), dtype=cluster.dtype).indexify()
make = lambda: Symbol(name=sregistry.make_name(), dtype=cluster.dtype).indexify()
processed = _cse(cluster.exprs, make)

return cluster.rebuild(processed)
Expand Down
8 changes: 4 additions & 4 deletions devito/passes/clusters/factorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sympy import Add, Mul, collect

from devito.passes.clusters.utils import cluster_pass
from devito.symbolics import estimate_cost, retrieve_scalars
from devito.symbolics import estimate_cost, retrieve_symbols
from devito.tools import ReducerMap

__all__ = ['factorize']
Expand Down Expand Up @@ -155,9 +155,9 @@ def run(expr):

# Collect common temporaries (r0, r1, ...)
w_coeffs = Add(*w_coeffs, evaluate=False)
scalars = retrieve_scalars(w_coeffs)
if scalars:
w_coeffs = collect(w_coeffs, scalars, evaluate=False)
symbols = retrieve_symbols(w_coeffs)
if symbols:
w_coeffs = collect(w_coeffs, symbols, evaluate=False)
try:
terms.extend([Mul(k, collect_const(v), evaluate=False)
for k, v in w_coeffs.items()])
Expand Down
4 changes: 2 additions & 2 deletions devito/passes/clusters/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from devito.passes.clusters.utils import cluster_pass
from devito.symbolics import pow_to_mul, uxreplace
from devito.tools import DAG, as_tuple, filter_ordered, frozendict, timed_pass
from devito.types import Scalar
from devito.types import Symbol

__all__ = ['Lift', 'fuse', 'eliminate_arrays', 'optimize_pows', 'extract_increments']

Expand Down Expand Up @@ -316,7 +316,7 @@ def extract_increments(cluster, sregistry, *args):
processed = []
for e in cluster.exprs:
if e.is_Increment and e.lhs.function.is_Input:
handle = Scalar(name=sregistry.make_name(), dtype=e.dtype).indexify()
handle = Symbol(name=sregistry.make_name(), dtype=e.dtype).indexify()
if e.rhs.is_Number or e.rhs.is_Symbol:
extracted = e.rhs
else:
Expand Down
2 changes: 1 addition & 1 deletion devito/passes/iet/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def filter_args(v, efunc=None):
continue

if efunc is self.root and not (a.is_Input or a.is_Object):
# Temporaries (ie, Scalars, Arrays) *cannot* be args in `root`
# Temporaries (ie, Symbol, Arrays) *cannot* be args in `root`
continue

processed.append(a)
Expand Down
4 changes: 2 additions & 2 deletions devito/passes/iet/parpragma.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from devito.passes.iet.langbase import LangBB, LangTransformer, DeviceAwareMixin
from devito.passes.iet.misc import is_on_device
from devito.tools import as_tuple, is_integer, prod
from devito.types import Symbol, NThreadsMixin
from devito.types import Symbol, NThreadsBase

__all__ = ['PragmaSimdTransformer', 'PragmaShmTransformer',
'PragmaDeviceAwareTransformer', 'PragmaLangBB']
Expand Down Expand Up @@ -360,7 +360,7 @@ def _make_parallel(self, iet):
iet = Transformer(mapper).visit(iet)

# The new arguments introduced by this pass
args = [i for i in FindSymbols().visit(iet) if isinstance(i, (NThreadsMixin))]
args = [i for i in FindSymbols().visit(iet) if isinstance(i, (NThreadsBase))]
for n in FindNodes(VExpanded).visit(iet):
args.extend([(n.pointee, True), n.pointer])

Expand Down
6 changes: 3 additions & 3 deletions devito/symbolics/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

__all__ = ['q_leaf', 'q_indexed', 'q_terminal', 'q_function', 'q_routine', 'q_xop',
'q_terminalop', 'q_indirect', 'q_constant', 'q_affine', 'q_linear',
'q_identity', 'q_inc', 'q_scalar', 'q_multivar', 'q_monoaffine',
'q_identity', 'q_inc', 'q_symbol', 'q_multivar', 'q_monoaffine',
'q_dimension']


Expand All @@ -16,9 +16,9 @@
# * Indexed


def q_scalar(expr):
def q_symbol(expr):
try:
return expr.is_Scalar
return expr.is_Symbol
except AttributeError:
return False

Expand Down
8 changes: 4 additions & 4 deletions devito/symbolics/search.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from devito.symbolics.queries import (q_indexed, q_function, q_terminal, q_leaf, q_xop,
q_scalar, q_dimension)
q_symbol, q_dimension)
from devito.tools import as_tuple

__all__ = ['retrieve_indexed', 'retrieve_functions', 'retrieve_function_carriers',
'retrieve_terminals', 'retrieve_xops', 'retrieve_scalars',
'retrieve_terminals', 'retrieve_xops', 'retrieve_symbols',
'retrieve_dimensions', 'search']


Expand Down Expand Up @@ -139,9 +139,9 @@ def retrieve_functions(exprs, mode='all'):
return search(exprs, q_function, mode, 'dfs')


def retrieve_scalars(exprs, mode='all'):
def retrieve_symbols(exprs, mode='all'):
"""Shorthand to retrieve the Scalar in ``exprs``."""
return search(exprs, q_scalar, mode, 'dfs')
return search(exprs, q_symbol, mode, 'dfs')


def retrieve_function_carriers(exprs, mode='all'):
Expand Down
20 changes: 17 additions & 3 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,14 +466,28 @@ class Scalar(Symbol, ArgProvider):
def __dtype_setup__(cls, **kwargs):
return kwargs.get('dtype', np.float32)

def _arg_defaults(self):
return {}
@property
def default_value(self):
return None

@property
def _arg_names(self):
return (self.name,)

def _arg_defaults(self, **kwargs):
if self.default_value is None:
# It is possible that the Scalar value is provided indirectly
# through a wrapper object (e.g., a Dimension spacing `h_x` gets its
# value via a Grid object)
return {}
else:
return {self.name: self.default_value}

def _arg_values(self, **kwargs):
if self.name in kwargs:
return {self.name: kwargs.pop(self.name)}
else:
return {}
return self._arg_defaults()


class AbstractTensor(sympy.ImmutableDenseMatrix, Basic, Pickable, Evaluable):
Expand Down
Loading

0 comments on commit 43788f3

Please sign in to comment.