From 97779f3ef78df6b71e20628fc81fcbb3fcfb4a32 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 1 Apr 2021 15:32:25 +0200 Subject: [PATCH 1/2] types: Singletonize special symbols (e.g. nthreads) --- devito/builtins/arithmetic.py | 6 +- devito/core/operator.py | 2 +- devito/operations/interpolators.py | 4 +- devito/operator/symbols.py | 6 +- devito/passes/clusters/aliases.py | 6 +- devito/passes/clusters/cse.py | 4 +- devito/passes/clusters/factorization.py | 8 +- devito/passes/clusters/misc.py | 4 +- devito/passes/iet/engine.py | 2 +- devito/passes/iet/parpragma.py | 4 +- devito/symbolics/queries.py | 6 +- devito/symbolics/search.py | 8 +- devito/types/basic.py | 20 ++++- devito/types/parallel.py | 81 ++++++++------------ devito/types/sparse.py | 6 +- examples/seismic/viscoacoustic/wavesolver.py | 9 +-- tests/test_caching.py | 34 +++++++- tests/test_dle.py | 26 ++++--- tests/test_gpu_common.py | 4 +- tests/test_pickle.py | 11 ++- 20 files changed, 141 insertions(+), 110 deletions(-) diff --git a/devito/builtins/arithmetic.py b/devito/builtins/arithmetic.py index 2ba9818476..0225ef3d77 100644 --- a/devito/builtins/arithmetic.py +++ b/devito/builtins/arithmetic.py @@ -28,7 +28,7 @@ def norm(f, order=2): # otherwise we would eventually be summing more than expected p, eqns = f.guard() if f.is_SparseFunction else (f, []) - s = dv.types.Scalar(name='sum', dtype=f.dtype) + s = dv.types.Symbol(name='sum', dtype=f.dtype) with MPIReduction(f) as mr: op = dv.Operator([dv.Eq(s, 0.0)] + @@ -59,7 +59,7 @@ def sumall(f): # otherwise we would eventually be summing more than expected p, eqns = f.guard() if f.is_SparseFunction else (f, []) - s = dv.types.Scalar(name='sum', dtype=f.dtype) + s = dv.types.Symbol(name='sum', dtype=f.dtype) with MPIReduction(f) as mr: op = dv.Operator([dv.Eq(s, 0.0)] + @@ -113,7 +113,7 @@ def inner(f, g): # otherwise we would eventually be summing more than expected rhs, eqns = f.guard(f*g) if f.is_SparseFunction else (f*g, []) - s = 
dv.types.Scalar(name='sum', dtype=f.dtype) + s = dv.types.Symbol(name='sum', dtype=f.dtype) with MPIReduction(f, g) as mr: op = dv.Operator([dv.Eq(s, 0.0)] + diff --git a/devito/core/operator.py b/devito/core/operator.py index d2105c635a..e2910f2146 100644 --- a/devito/core/operator.py +++ b/devito/core/operator.py @@ -64,7 +64,7 @@ def _autotune(self, args, setup): @property def nthreads(self): - nthreads = [i for i in self.input if type(i).__base__ is NThreads] + nthreads = [i for i in self.input if isinstance(i, NThreads)] if len(nthreads) == 0: return 1 else: diff --git a/devito/operations/interpolators.py b/devito/operations/interpolators.py index a11899a7f9..bc49c498cb 100644 --- a/devito/operations/interpolators.py +++ b/devito/operations/interpolators.py @@ -8,7 +8,7 @@ from devito.symbolics import retrieve_function_carriers, indexify, INT from devito.tools import powerset, flatten, prod from devito.types import (ConditionalDimension, Dimension, DefaultDimension, Eq, Inc, - Evaluable, Scalar, SubFunction) + Evaluable, Symbol, SubFunction) __all__ = ['LinearInterpolator', 'PrecomputedInterpolator'] @@ -234,7 +234,7 @@ def callback(): for b, v_sub in zip(self._interpolation_coeffs, idx_subs)] # Accumulate point-wise contributions into a temporary - rhs = Scalar(name='sum', dtype=self.sfunction.dtype) + rhs = Symbol(name='sum', dtype=self.sfunction.dtype) summands = [Eq(rhs, 0., implicit_dims=self.sfunction.dimensions)] summands.extend([Inc(rhs, i, implicit_dims=self.sfunction.dimensions) for i in args]) diff --git a/devito/operator/symbols.py b/devito/operator/symbols.py index 3f7037c680..19f46f9d15 100644 --- a/devito/operator/symbols.py +++ b/devito/operator/symbols.py @@ -15,9 +15,9 @@ def __init__(self): self.counters = {} # Special symbols - self.nthreads = NThreads(aliases='nthreads0') - self.nthreads_nested = NThreadsNested(aliases='nthreads1') - self.nthreads_nonaffine = NThreadsNonaffine(aliases='nthreads2') + self.nthreads = NThreads() + 
self.nthreads_nested = NThreadsNested() + self.nthreads_nonaffine = NThreadsNonaffine() self.threadid = ThreadID(self.nthreads) # Several groups of pthreads each of size `npthread` may be created diff --git a/devito/passes/clusters/aliases.py b/devito/passes/clusters/aliases.py index 9e36a91a47..7901a15e89 100644 --- a/devito/passes/clusters/aliases.py +++ b/devito/passes/clusters/aliases.py @@ -13,7 +13,7 @@ from devito.symbolics import (Uxmapper, compare_ops, estimate_cost, q_constant, q_leaf, retrieve_indexed, search, uxreplace) from devito.tools import as_tuple, flatten, split -from devito.types import (Array, TempFunction, Eq, Scalar, ModuloDimension, +from devito.types import (Array, TempFunction, Eq, Symbol, ModuloDimension, CustomDimension, IncrDimension) __all__ = ['cire'] @@ -172,7 +172,7 @@ def make_schedule(self, cluster, context): return SpacePoint(schedule, exprs, score) def _make_symbol(self): - return Scalar(name=self.sregistry.make_name('dummy')) + return Symbol(name=self.sregistry.make_name('dummy')) def _nrepeats(self, cluster): raise NotImplementedError @@ -801,7 +801,7 @@ def lower_schedule(cluster, schedule, sregistry, options): # Degenerate case: scalar expression assert writeto.size == 0 - obj = Scalar(name=name, dtype=dtype) + obj = Symbol(name=name, dtype=dtype) expression = Eq(obj, alias) callback = lambda idx: obj diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index 5b559c8631..d4a9a1b8dd 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -3,7 +3,7 @@ from devito.ir import DummyEq, Cluster, Scope from devito.passes.clusters.utils import cluster_pass, makeit_ssa from devito.symbolics import count, estimate_cost, q_xop, q_leaf, uxreplace -from devito.types import Scalar +from devito.types import Symbol __all__ = ['cse'] @@ -13,7 +13,7 @@ def cse(cluster, sregistry, *args): """ Common sub-expressions elimination (CSE). 
""" - make = lambda: Scalar(name=sregistry.make_name(), dtype=cluster.dtype).indexify() + make = lambda: Symbol(name=sregistry.make_name(), dtype=cluster.dtype).indexify() processed = _cse(cluster.exprs, make) return cluster.rebuild(processed) diff --git a/devito/passes/clusters/factorization.py b/devito/passes/clusters/factorization.py index 0d54911185..8468a473e1 100644 --- a/devito/passes/clusters/factorization.py +++ b/devito/passes/clusters/factorization.py @@ -3,7 +3,7 @@ from sympy import Add, Mul, collect from devito.passes.clusters.utils import cluster_pass -from devito.symbolics import estimate_cost, retrieve_scalars +from devito.symbolics import estimate_cost, retrieve_symbols from devito.tools import ReducerMap __all__ = ['factorize'] @@ -155,9 +155,9 @@ def run(expr): # Collect common temporaries (r0, r1, ...) w_coeffs = Add(*w_coeffs, evaluate=False) - scalars = retrieve_scalars(w_coeffs) - if scalars: - w_coeffs = collect(w_coeffs, scalars, evaluate=False) + symbols = retrieve_symbols(w_coeffs) + if symbols: + w_coeffs = collect(w_coeffs, symbols, evaluate=False) try: terms.extend([Mul(k, collect_const(v), evaluate=False) for k, v in w_coeffs.items()]) diff --git a/devito/passes/clusters/misc.py b/devito/passes/clusters/misc.py index 440be0b2e8..28898cd231 100644 --- a/devito/passes/clusters/misc.py +++ b/devito/passes/clusters/misc.py @@ -6,7 +6,7 @@ from devito.passes.clusters.utils import cluster_pass from devito.symbolics import pow_to_mul, uxreplace from devito.tools import DAG, as_tuple, filter_ordered, frozendict, timed_pass -from devito.types import Scalar +from devito.types import Symbol __all__ = ['Lift', 'fuse', 'eliminate_arrays', 'optimize_pows', 'extract_increments'] @@ -316,7 +316,7 @@ def extract_increments(cluster, sregistry, *args): processed = [] for e in cluster.exprs: if e.is_Increment and e.lhs.function.is_Input: - handle = Scalar(name=sregistry.make_name(), dtype=e.dtype).indexify() + handle = Symbol(name=sregistry.make_name(), 
dtype=e.dtype).indexify() if e.rhs.is_Number or e.rhs.is_Symbol: extracted = e.rhs else: diff --git a/devito/passes/iet/engine.py b/devito/passes/iet/engine.py index 281584b18b..7e38848195 100644 --- a/devito/passes/iet/engine.py +++ b/devito/passes/iet/engine.py @@ -118,7 +118,7 @@ def filter_args(v, efunc=None): continue if efunc is self.root and not (a.is_Input or a.is_Object): - # Temporaries (ie, Scalars, Arrays) *cannot* be args in `root` + # Temporaries (i.e., Symbols, Arrays) *cannot* be args in `root` continue processed.append(a) diff --git a/devito/passes/iet/parpragma.py b/devito/passes/iet/parpragma.py index 21c93ec988..9061fd8083 100644 --- a/devito/passes/iet/parpragma.py +++ b/devito/passes/iet/parpragma.py @@ -12,7 +12,7 @@ from devito.passes.iet.langbase import LangBB, LangTransformer, DeviceAwareMixin from devito.passes.iet.misc import is_on_device from devito.tools import as_tuple, is_integer, prod -from devito.types import Symbol, NThreadsMixin +from devito.types import Symbol, NThreadsBase __all__ = ['PragmaSimdTransformer', 'PragmaShmTransformer', 'PragmaDeviceAwareTransformer', 'PragmaLangBB'] @@ -360,7 +360,7 @@ def _make_parallel(self, iet): iet = Transformer(mapper).visit(iet) # The new arguments introduced by this pass - args = [i for i in FindSymbols().visit(iet) if isinstance(i, (NThreadsMixin))] + args = [i for i in FindSymbols().visit(iet) if isinstance(i, (NThreadsBase))] for n in FindNodes(VExpanded).visit(iet): args.extend([(n.pointee, True), n.pointer]) diff --git a/devito/symbolics/queries.py b/devito/symbolics/queries.py index feec5bcf41..432bb2bfcb 100644 --- a/devito/symbolics/queries.py +++ b/devito/symbolics/queries.py @@ -5,7 +5,7 @@ __all__ = ['q_leaf', 'q_indexed', 'q_terminal', 'q_function', 'q_routine', 'q_xop', 'q_terminalop', 'q_indirect', 'q_constant', 'q_affine', 'q_linear', - 'q_identity', 'q_inc', 'q_scalar', 'q_multivar', 'q_monoaffine', + 'q_identity', 'q_inc', 'q_symbol', 'q_multivar', 'q_monoaffine', 
'q_dimension'] @@ -16,9 +16,9 @@ # * Indexed -def q_scalar(expr): +def q_symbol(expr): try: - return expr.is_Scalar + return expr.is_Symbol except AttributeError: return False diff --git a/devito/symbolics/search.py b/devito/symbolics/search.py index 313edf1f63..60d54d5af9 100644 --- a/devito/symbolics/search.py +++ b/devito/symbolics/search.py @@ -1,9 +1,9 @@ from devito.symbolics.queries import (q_indexed, q_function, q_terminal, q_leaf, q_xop, - q_scalar, q_dimension) + q_symbol, q_dimension) from devito.tools import as_tuple __all__ = ['retrieve_indexed', 'retrieve_functions', 'retrieve_function_carriers', - 'retrieve_terminals', 'retrieve_xops', 'retrieve_scalars', + 'retrieve_terminals', 'retrieve_xops', 'retrieve_symbols', 'retrieve_dimensions', 'search'] @@ -139,9 +139,9 @@ def retrieve_functions(exprs, mode='all'): return search(exprs, q_function, mode, 'dfs') -def retrieve_scalars(exprs, mode='all'): +def retrieve_symbols(exprs, mode='all'): """Shorthand to retrieve the Scalar in ``exprs``.""" - return search(exprs, q_scalar, mode, 'dfs') + return search(exprs, q_symbol, mode, 'dfs') def retrieve_function_carriers(exprs, mode='all'): diff --git a/devito/types/basic.py b/devito/types/basic.py index 205860362d..befdb66370 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -466,14 +466,28 @@ class Scalar(Symbol, ArgProvider): def __dtype_setup__(cls, **kwargs): return kwargs.get('dtype', np.float32) - def _arg_defaults(self): - return {} + @property + def default_value(self): + return None + + @property + def _arg_names(self): + return (self.name,) + + def _arg_defaults(self, **kwargs): + if self.default_value is None: + # It is possible that the Scalar value is provided indirectly + # through a wrapper object (e.g., a Dimension spacing `h_x` gets its + # value via a Grid object) + return {} + else: + return {self.name: self.default_value} def _arg_values(self, **kwargs): if self.name in kwargs: return {self.name: kwargs.pop(self.name)} else: 
- return {} + return self._arg_defaults() class AbstractTensor(sympy.ImmutableDenseMatrix, Basic, Pickable, Evaluable): diff --git a/devito/types/parallel.py b/devito/types/parallel.py index 70105575bf..96a86b3b48 100644 --- a/devito/types/parallel.py +++ b/devito/types/parallel.py @@ -17,77 +17,61 @@ from devito.parameters import configuration from devito.tools import Pickable, as_list, as_tuple, dtype_to_cstr, filter_ordered from devito.types.array import Array, ArrayObject -from devito.types.basic import Symbol -from devito.types.constant import Constant +from devito.types.basic import Scalar, Symbol from devito.types.dimension import CustomDimension from devito.types.misc import VolatileInt, c_volatile_int_p -__all__ = ['NThreads', 'NThreadsNested', 'NThreadsNonaffine', 'NThreadsMixin', 'DeviceID', +__all__ = ['NThreads', 'NThreadsNested', 'NThreadsNonaffine', 'NThreadsBase', 'DeviceID', 'ThreadID', 'Lock', 'WaitLock', 'WithLock', 'FetchWait', 'FetchWaitPrefetch', 'Delete', 'PThreadArray', 'SharedData', 'NPThreads', 'DeviceRM', 'normalize_syncs'] -class NThreadsMixin(object): +class NThreadsBase(Scalar): + is_Input = True is_PerfKnob = True def __new__(cls, **kwargs): - name = kwargs.get('name', cls.name) - value = cls.__value_setup__(**kwargs) - obj = Constant.__new__(cls, name=name, dtype=np.int32, value=value) - obj.aliases = as_tuple(kwargs.get('aliases')) + (name,) - return obj + kwargs.setdefault('name', cls.name) + kwargs['is_const'] = True + return super().__new__(cls, **kwargs) @classmethod - def __value_setup__(cls, **kwargs): - try: - return kwargs.pop('value') - except KeyError: - return cls.default_value() - - @property - def _arg_names(self): - return self.aliases + def __dtype_setup__(cls, **kwargs): + return np.int32 - def _arg_values(self, **kwargs): - for i in self.aliases: - if i in kwargs: - return {self.name: kwargs.pop(i)} - # Fallback: as usual, pick the default value - return self._arg_defaults() + @cached_property + def 
default_value(self): + return int(os.environ.get('OMP_NUM_THREADS', + configuration['platform'].cores_physical)) -class NThreads(NThreadsMixin, Constant): +class NThreads(NThreadsBase): name = 'nthreads' - @classmethod - def default_value(cls): - return int(os.environ.get('OMP_NUM_THREADS', - configuration['platform'].cores_physical)) - -class NThreadsNested(NThreadsMixin, Constant): +class NThreadsNonaffine(NThreadsBase): - name = 'nthreads_nested' + name = 'nthreads_nonaffine' - @classmethod - def default_value(cls): - return configuration['platform'].threads_per_core +class NThreadsNested(NThreadsBase): -class NThreadsNonaffine(NThreads): + name = 'nthreads_nested' - name = 'nthreads_nonaffine' + @property + def default_value(self): + return configuration['platform'].threads_per_core -class NPThreads(NThreadsMixin, Constant): +class NPThreads(NThreadsBase): name = 'npthreads' - @classmethod - def default_value(cls): + @property + def default_value(self): return 1 @@ -344,14 +328,15 @@ def normalize_syncs(*args): return syncs -class DeviceSymbol(Constant): +class DeviceSymbol(Scalar): + is_Input = True is_PerfKnob = True - def __new__(cls, *args, **kwargs): + def __new__(cls, **kwargs): kwargs['name'] = cls.name - kwargs['value'] = cls.__value_setup__(**kwargs) - return Constant.__new__(cls, *args, **kwargs) + kwargs['is_const'] = True + return super().__new__(cls, **kwargs) @classmethod def __dtype_setup__(cls, **kwargs): @@ -362,8 +347,8 @@ class DeviceID(DeviceSymbol): name = 'deviceid' - @classmethod - def __value_setup__(cls, **kwargs): + @property + def default_value(self): return -1 @@ -371,8 +356,8 @@ class DeviceRM(DeviceSymbol): name = 'devicerm' - @classmethod - def __value_setup__(cls, **kwargs): + @property + def default_value(self): return 1 def _arg_values(self, **kwargs): diff --git a/devito/types/sparse.py b/devito/types/sparse.py index cb07e02408..46519e7608 100644 --- a/devito/types/sparse.py +++ b/devito/types/sparse.py @@ -14,7 +14,7 @@ 
memoized_meth, is_integer) from devito.types.dense import DiscreteFunction, Function, SubFunction from devito.types.dimension import Dimension, ConditionalDimension, DefaultDimension -from devito.types.basic import Symbol, Scalar +from devito.types.basic import Symbol from devito.types.equation import Eq, Inc __all__ = ['SparseFunction', 'SparseTimeFunction', 'PrecomputedSparseFunction', @@ -540,7 +540,7 @@ def coordinates_data(self): @cached_property def _point_symbols(self): """Symbol for coordinate value in each dimension of the point.""" - return tuple(Scalar(name='p%s' % d, dtype=self.dtype) + return tuple(Symbol(name='p%s' % d, dtype=self.dtype) for d in self.grid.dimensions) @cached_property @@ -559,7 +559,7 @@ def _position_map(self): the position. We mitigate this problem by computing the positions individually (hence the need for a position map). """ - symbols = [Scalar(name='pos%s' % d, dtype=self.dtype) + symbols = [Symbol(name='pos%s' % d, dtype=self.dtype) for d in self.grid.dimensions] return OrderedDict([(c - o, p) for p, c, o in zip(symbols, self._coordinate_symbols, diff --git a/examples/seismic/viscoacoustic/wavesolver.py b/examples/seismic/viscoacoustic/wavesolver.py index 0ff9b9b2c0..b9e51b16d4 100755 --- a/examples/seismic/viscoacoustic/wavesolver.py +++ b/examples/seismic/viscoacoustic/wavesolver.py @@ -191,6 +191,7 @@ def adjoint(self, rec, srca=None, va=None, pa=None, r=None, model=None, **kwargs time_order=self.time_order, space_order=self.space_order) kwargs.update({k.name: k for k in va}) + kwargs['time_m'] = 0 pa = pa or TimeFunction(name="pa", grid=self.model.grid, time_order=self.time_order, space_order=self.space_order, @@ -209,14 +210,10 @@ def adjoint(self, rec, srca=None, va=None, pa=None, r=None, model=None, **kwargs # Execute operator and return wavefield and receiver data # With Memory variable summary = self.op_adj().apply(src=srca, rec=rec, pa=pa, r=r, - dt=kwargs.pop('dt', self.dt), - time_m=0 if self.time_order == 1 else 
None, - **kwargs) + dt=kwargs.pop('dt', self.dt), **kwargs) else: summary = self.op_adj().apply(src=srca, rec=rec, pa=pa, - dt=kwargs.pop('dt', self.dt), - time_m=0 if self.time_order == 1 else None, - **kwargs) + dt=kwargs.pop('dt', self.dt), **kwargs) return srca, pa, va, summary def jacobian_adjoint(self, rec, p, pa=None, grad=None, r=None, model=None, diff --git a/tests/test_caching.py b/tests/test_caching.py index 00a0aabc27..6318b4619b 100644 --- a/tests/test_caching.py +++ b/tests/test_caching.py @@ -7,7 +7,7 @@ ConditionalDimension, SubDimension, Constant, Operator, Eq, Dimension, DefaultDimension, _SymbolCache, clear_cache, solve, VectorFunction, TensorFunction, TensorTimeFunction, VectorTimeFunction) -from devito.types.basic import Scalar, Symbol +from devito.types import Scalar, Symbol, NThreadsBase, DeviceID, ThreadID @pytest.fixture @@ -377,6 +377,38 @@ def test_grid_objs(self): assert ox0 is ox1 assert oy0 is oy1 + def test_special_symbols(self): + """ + This test checks the singletonization, through the caching infrastructure, + of the special symbols that an Operator may generate (e.g., `nthreads`). 
+ """ + grid = Grid(shape=(4, 4, 4)) + f = TimeFunction(name='f', grid=grid) + sf = SparseTimeFunction(name='sf', grid=grid, npoint=1, nt=10) + + eqns = [Eq(f.forward, f + 1.)] + sf.inject(field=f.forward, expr=sf) + + opt = ('advanced', {'par-nested': 0, 'openmp': True}) + op0 = Operator(eqns, opt=opt) + op1 = Operator(eqns, opt=opt) + + nthreads0, nthreads_nested0, nthreads_nonaffine0 =\ + [i for i in op0.input if isinstance(i, NThreadsBase)] + nthreads1, nthreads_nested1, nthreads_nonaffine1 =\ + [i for i in op1.input if isinstance(i, NThreadsBase)] + + assert nthreads0 is nthreads1 + assert nthreads_nested0 is nthreads_nested1 + assert nthreads_nonaffine0 is nthreads_nonaffine1 + + tid0 = ThreadID(op0.nthreads) + tid1 = ThreadID(op0.nthreads) + assert tid0 is tid1 + + did0 = DeviceID() + did1 = DeviceID() + assert did0 is did1 + def test_symbol_aliasing(self): """Test to assert that our aliasing cache isn't defeated by sympys non-aliasing symbol cache. diff --git a/tests/test_dle.py b/tests/test_dle.py index dba682b3cd..7d9fe837c9 100644 --- a/tests/test_dle.py +++ b/tests/test_dle.py @@ -10,7 +10,7 @@ from devito.ir.iet import Call, Iteration, Conditional, FindNodes, retrieve_iteration_tree from devito.passes.iet.languages.openmp import OmpRegion from devito.tools import as_tuple -from devito.types import Scalar, NThreads, NThreadsNonaffine +from devito.types import Scalar def get_blocksizes(op, opt, grid, blockshape, level=0): @@ -360,6 +360,21 @@ def test_cache_blocking_imperfect_nest_v2(blockinner): class TestNodeParallelism(object): + def test_nthreads_generation(self): + grid = Grid(shape=(10, 10)) + + f = TimeFunction(name='f', grid=grid) + + eq = Eq(f.forward, f + 1) + + op0 = Operator(eq, openmp=True) + + # `nthreads` must appear among the Operator parameters + assert op0.nthreads in op0.parameters + + # `nthreads` is bindable to a runtime value + assert op0.nthreads._arg_values() + @pytest.mark.parametrize('exprs,expected', [ # trivial 1D (['Eq(fa[x], 
fa[x] + fb[x])'], @@ -452,15 +467,8 @@ def test_dynamic_nthreads(self): assert np.all(f.data[0] == 2.) # Check the actual value assumed by `nthreads` and `nthreads_nonaffine` - assert op.arguments(time=0)['nthreads'] == NThreads.default_value() - assert op.arguments(time=0)['nthreads_nonaffine'] == \ - NThreadsNonaffine.default_value() - # Again, but with user-supplied values assert op.arguments(time=0, nthreads=123)['nthreads'] == 123 assert op.arguments(time=0, nthreads_nonaffine=100)['nthreads_nonaffine'] == 100 - # Again, but with the aliases - assert op.arguments(time=0, nthreads0=123)['nthreads'] == 123 - assert op.arguments(time=0, nthreads2=123)['nthreads_nonaffine'] == 123 @pytest.mark.parametrize('eqns,expected,blocking', [ ('[Eq(f, 2*f)]', [2, 0, 0], False), @@ -633,8 +641,6 @@ def test_basic(self): op.apply(t_M=9, nthreads=1, nthreads_nested=2) assert np.all(u.data[0] == 10) assert op.arguments(t_M=9, nthreads_nested=2)['nthreads_nested'] == 2 - # Same as above, but with the alias - assert op.arguments(t_M=9, nthreads1=2)['nthreads_nested'] == 2 iterations = FindNodes(Iteration).visit(op._func_table['bf0']) assert iterations[0].pragmas[0].value == 'omp for collapse(1) schedule(dynamic,1)' diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 4d997a2fdc..7636340a81 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -727,7 +727,6 @@ def check_deviceid(self): deviceid = self.get_param(op, DeviceID) assert deviceid is not None - assert deviceid.data == -1 assert op.arguments()[deviceid.name] == -1 assert op.arguments(deviceid=0)[deviceid.name] == 0 @@ -748,8 +747,7 @@ def test_devicerm(self): devicerm = self.get_param(op, DeviceRM) assert devicerm is not None - assert devicerm.data == 1 # Always evict, by default - assert op.arguments(time_M=2)[devicerm.name] == 1 + assert op.arguments(time_M=2)[devicerm.name] == 1 # Always evict by default assert op.arguments(time_M=2, devicerm=0)[devicerm.name] == 0 assert 
op.arguments(time_M=2, devicerm=1)[devicerm.name] == 1 assert op.arguments(time_M=2, devicerm=224)[devicerm.name] == 1 diff --git a/tests/test_pickle.py b/tests/test_pickle.py index 4dc61b5202..fbf89c69cf 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -516,19 +516,18 @@ def test_compilerfunction(): assert new_pcf.ndim == cf.ndim + 1 -@skipif(['nompi']) -@pytest.mark.parallel(mode=[1]) def test_deviceid(): - grid = Grid(shape=(4, 4, 4)) - - did = DeviceID(grid.distributor._obj_comm) + did = DeviceID() pkl_did = pickle.dumps(did) new_did = pickle.loads(pkl_did) + # TODO: this will be extended when we support DeviceID + # for multi-node multi-gpu execution, when DeviceID will have + # to pick its default value from an MPI communicator attached + # to the runtime arguments assert did.name == new_did.name assert did.dtype == new_did.dtype - assert did.data == new_did.data @skipif(['nompi']) From 7b3888f57944ea7dcfd6f9d6a556e0335f7ebc58 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 6 Apr 2021 15:06:09 +0200 Subject: [PATCH 2/2] types: Fixup generation, caching and processing of NPThreads --- devito/ir/iet/efunc.py | 2 +- devito/operator/symbols.py | 4 ++-- devito/types/parallel.py | 25 ++++++++++++++++++++++--- tests/test_caching.py | 8 +++++++- tests/test_gpu_common.py | 32 +++++++++++++++++++++++++++----- tests/test_pickle.py | 13 ++++++++++++- 6 files changed, 71 insertions(+), 13 deletions(-) diff --git a/devito/ir/iet/efunc.py b/devito/ir/iet/efunc.py index 14a88d1744..66044b9e66 100644 --- a/devito/ir/iet/efunc.py +++ b/devito/ir/iet/efunc.py @@ -118,7 +118,7 @@ class ThreadFunction(Callable): def _make_threads(value, sregistry): name = sregistry.make_name(prefix='threads') - base_id = 1 + sum(i.data for i in sregistry.npthreads) + base_id = 1 + sum(i.size for i in sregistry.npthreads) if value is None: # The npthreads Symbol isn't actually used, but we record the fact diff --git a/devito/operator/symbols.py b/devito/operator/symbols.py 
index 19f46f9d15..1adb9d7949 100644 --- a/devito/operator/symbols.py +++ b/devito/operator/symbols.py @@ -36,8 +36,8 @@ def make_name(self, prefix=None): return "%s%d" % (prefix, counter()) - def make_npthreads(self, value): + def make_npthreads(self, size): name = self.make_name(prefix='npthreads') - npthreads = NPThreads(name=name, value=value) + npthreads = NPThreads(name=name, size=size) self.npthreads.append(npthreads) return npthreads diff --git a/devito/types/parallel.py b/devito/types/parallel.py index 96a86b3b48..ef142cbf5a 100644 --- a/devito/types/parallel.py +++ b/devito/types/parallel.py @@ -14,6 +14,7 @@ import numpy as np import sympy +from devito.exceptions import InvalidArgument from devito.parameters import configuration from devito.tools import Pickable, as_list, as_tuple, dtype_to_cstr, filter_ordered from devito.types.array import Array, ArrayObject @@ -70,9 +71,27 @@ class NPThreads(NThreadsBase): name = 'npthreads' - @property - def default_value(self): - return 1 + def __new__(cls, **kwargs): + obj = super().__new__(cls, **kwargs) + + # Size of the thread pool + obj.size = kwargs['size'] + + return obj + + def _arg_values(self, **kwargs): + if self.name in kwargs: + v = kwargs.pop(self.name) + if v < self.size: + return {self.name: v} + else: + raise InvalidArgument("Illegal `%s=%d`. 
It must be `%s<%d`" + % (self.name, v, self.name, self.size)) + else: + return self._arg_defaults() + + # Pickling support + _pickle_kwargs = NThreadsBase._pickle_kwargs + ['size'] class ThreadID(CustomDimension): diff --git a/tests/test_caching.py b/tests/test_caching.py index 6318b4619b..6c62770e58 100644 --- a/tests/test_caching.py +++ b/tests/test_caching.py @@ -7,7 +7,7 @@ ConditionalDimension, SubDimension, Constant, Operator, Eq, Dimension, DefaultDimension, _SymbolCache, clear_cache, solve, VectorFunction, TensorFunction, TensorTimeFunction, VectorTimeFunction) -from devito.types import Scalar, Symbol, NThreadsBase, DeviceID, ThreadID +from devito.types import Scalar, Symbol, NThreadsBase, DeviceID, NPThreads, ThreadID @pytest.fixture @@ -409,6 +409,12 @@ def test_special_symbols(self): did1 = DeviceID() assert did0 is did1 + npt0 = NPThreads(name='npt', size=3) + npt1 = NPThreads(name='npt', size=3) + npt2 = NPThreads(name='npt', size=4) + assert npt0 is npt1 + assert npt0 is not npt2 + def test_symbol_aliasing(self): """Test to assert that our aliasing cache isn't defeated by sympys non-aliasing symbol cache. 
diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 7636340a81..cc72039f85 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -4,9 +4,10 @@ from devito import (Constant, Eq, Inc, Grid, Function, ConditionalDimension, SubDimension, SubDomain, TimeFunction, Operator) from devito.arch import get_gpu_info +from devito.exceptions import InvalidArgument from devito.ir import Expression, Section, FindNodes, FindSymbols, retrieve_iteration_tree from devito.passes.iet.languages.openmp import OmpIteration -from devito.types import DeviceID, DeviceRM, Lock, PThreadArray +from devito.types import DeviceID, DeviceRM, Lock, NPThreads, PThreadArray from conftest import skipif @@ -406,7 +407,7 @@ def test_composite_streaming_tasking(self): threads = [i for i in symbols if isinstance(i, PThreadArray)] assert len(threads) == 2 assert threads[0].size == 1 - assert threads[1].size.data == 2 + assert threads[1].size.size == 2 op0.apply(time_M=nt-1) op1.apply(time_M=nt-1, u=u1, usave=usave1) @@ -437,7 +438,7 @@ def test_composite_buffering_tasking(self): assert len([i for i in symbols if isinstance(i, Lock)]) == 1 threads = [i for i in symbols if isinstance(i, PThreadArray)] assert len(threads) == 1 - assert threads[0].size.data == 1 + assert threads[0].size.size == 1 op0.apply(time_M=nt-1, dt=0.1) op1.apply(time_M=nt-1, dt=0.1, u=u1, usave=usave1) @@ -475,8 +476,8 @@ def test_composite_buffering_tasking_multi_output(self): assert len([i for i in symbols if isinstance(i, Lock)]) == 2 threads = [i for i in symbols if isinstance(i, PThreadArray)] assert len(threads) == 2 - assert threads[0].size.data == 1 - assert threads[1].size.data == 1 + assert threads[0].size.size == 1 + assert threads[1].size.size == 1 assert len(op1._func_table) == 4 # usave and vsave eqns are in two diff efuncs op0.apply(time_M=nt-1) @@ -753,3 +754,24 @@ def test_devicerm(self): assert op.arguments(time_M=2, devicerm=224)[devicerm.name] == 1 assert op.arguments(time_M=2, 
devicerm=True)[devicerm.name] == 1 assert op.arguments(time_M=2, devicerm=False)[devicerm.name] == 0 + + def test_npthreads(self): + nt = 10 + async_degree = 5 + grid = Grid(shape=(300, 300, 300)) + + u = TimeFunction(name='u', grid=grid) + usave = TimeFunction(name='usave', grid=grid, save=nt) + + eqns = [Eq(u.forward, u + 1), + Eq(usave, u.forward)] + + op = Operator(eqns, opt=('buffering', 'tasking', 'orchestrate', + {'buf-async-degree': async_degree})) + + npthreads0 = self.get_param(op, NPThreads) + assert op.arguments(time_M=2)[npthreads0.name] == 1 + assert op.arguments(time_M=2, npthreads0=4)[npthreads0.name] == 4 + # Cannot provide a value larger than the thread pool size + with pytest.raises(InvalidArgument): + assert op.arguments(time_M=2, npthreads0=5) diff --git a/tests/test_pickle.py b/tests/test_pickle.py index fbf89c69cf..ec87674d53 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -13,7 +13,7 @@ MPIRegion) from devito.types import (Array, CustomDimension, Symbol as dSymbol, Scalar, PointerArray, Lock, PThreadArray, SharedData, Timer, - DeviceID, ThreadID, TempFunction) + DeviceID, NPThreads, ThreadID, TempFunction) from devito.symbolics import (IntDiv, ListInitializer, FieldFromPointer, FunctionFromPointer, DefFunction) from examples.seismic import (demo_model, AcquisitionGeometry, @@ -530,6 +530,17 @@ def test_deviceid(): assert did.dtype == new_did.dtype +def test_npthreads(): + npt = NPThreads(name='npt', size=3) + + pkl_npt = pickle.dumps(npt) + new_npt = pickle.loads(pkl_npt) + + assert npt.name == new_npt.name + assert npt.dtype == new_npt.dtype + assert npt.size == new_npt.size + + @skipif(['nompi']) @pytest.mark.parallel(mode=[(1, 'full')]) def test_mpi_fullmode_objects():