From cb919a6cbaa0cb12ffd68fc3e378955c56316398 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Sat, 26 Aug 2023 14:22:16 +0000 Subject: [PATCH 1/5] compiler: Simplify SubFunction --- devito/types/dense.py | 18 ++++-------------- devito/types/sparse.py | 23 ++++++++++++++++------- tests/test_interpolation.py | 1 - 3 files changed, 20 insertions(+), 22 deletions(-) diff --git a/devito/types/dense.py b/devito/types/dense.py index d9adfcedc3..4c912d2704 100644 --- a/devito/types/dense.py +++ b/devito/types/dense.py @@ -20,7 +20,7 @@ from devito.finite_differences import Differentiable, generate_fd_shortcuts from devito.tools import (ReducerMap, as_tuple, c_restrict_void_p, flatten, is_integer, memoized_meth, dtype_to_ctype, humanbytes) -from devito.types.dimension import Dimension +from devito.types.dimension import Dimension, DynamicDimension from devito.types.args import ArgProvider from devito.types.caching import CacheManager from devito.types.basic import AbstractFunction, Size @@ -1449,16 +1449,10 @@ class SubFunction(Function): """ A Function bound to a "parent" DiscreteFunction. - A SubFunction hands control of argument binding and halo exchange to its - parent DiscreteFunction. + A SubFunction hands control of argument binding and halo exchange to the + DiscreteFunction it's bound to. """ - __rkwargs__ = Function.__rkwargs__ + ('parent',) - - def __init_finalize__(self, *args, **kwargs): - super(SubFunction, self).__init_finalize__(*args, **kwargs) - self._parent = kwargs['parent'] - def __padding_setup__(self, **kwargs): # SubFunctions aren't expected to be used in time-consuming loops return tuple((0, 0) for i in range(self.ndim)) @@ -1470,12 +1464,8 @@ def _arg_values(self, **kwargs): if self.name in kwargs: raise RuntimeError("`%s` is a SubFunction, so it can't be assigned " "a value dynamically" % self.name) - else: - return self._parent._arg_defaults(alias=self._parent).reduce_all() - @property - def parent(self): - return self._parent + return self._arg_defaults(alias=self) @property def origin(self): diff --git a/devito/types/sparse.py b/devito/types/sparse.py index a1aef68f5f..4244c001ab 100644 --- a/devito/types/sparse.py +++ b/devito/types/sparse.py @@ -1179,6 +1179,15 @@ class PrecomputedSparseTimeFunction(AbstractSparseTimeFunction, PrecomputedSparseFunction.__rkwargs__)) +# *** MatrixSparse*Function API +# This is mostly legacy stuff which often escapes the devito's modus operandi + +class DynamicSubFunction(SubFunction): + + def _arg_defaults(self, **kwargs): + return {} + + class MatrixSparseTimeFunction(AbstractSparseTimeFunction): """ A specialised type of SparseTimeFunction where the interpolation is externally @@ -1378,7 +1387,7 @@ def __init_finalize__(self, *args, **kwargs): else: nnz_size = 1 - self._mrow = SubFunction( + self._mrow = DynamicSubFunction( name='mrow_%s' % self.name, dtype=np.int32, dimensions=(self.nnzdim,), @@ -1387,7 +1396,7 @@ def __init_finalize__(self, *args, **kwargs): parent=self, allocator=self._allocator, ) - self._mcol = SubFunction( + self._mcol = DynamicSubFunction( name='mcol_%s' % self.name, dtype=np.int32, dimensions=(self.nnzdim,), @@ -1396,7 +1405,7 @@ def __init_finalize__(self, *args, **kwargs): parent=self, allocator=self._allocator, ) - self._mval = SubFunction( + self._mval = DynamicSubFunction( name='mval_%s' % self.name, dtype=self.dtype, dimensions=(self.nnzdim,), @@ -1413,8 +1422,8 @@ def __init_finalize__(self, *args, **kwargs): self.par_dim_to_nnz_dim = DynamicDimension('par_dim_to_nnz_%s' % self.name) # This map acts as an indirect sort of the sources according to their - # position along the parallelisation Dimension - self._par_dim_to_nnz_map = SubFunction( + # position along the parallelisation dimension + self._par_dim_to_nnz_map = DynamicSubFunction( name='par_dim_to_nnz_map_%s' % self.name, dtype=np.int32, dimensions=(self.par_dim_to_nnz_dim,), @@ -1423,7 +1432,7 @@ def __init_finalize__(self, *args, **kwargs): space_order=0, parent=self, ) - self._par_dim_to_nnz_m = SubFunction( + self._par_dim_to_nnz_m = DynamicSubFunction( name='par_dim_to_nnz_m_%s' % self.name, dtype=np.int32, dimensions=(self._par_dim,), @@ -1432,7 +1441,7 @@ def __init_finalize__(self, *args, **kwargs): space_order=0, parent=self, ) - self._par_dim_to_nnz_M = SubFunction( + self._par_dim_to_nnz_M = DynamicSubFunction( name='par_dim_to_nnz_M_%s' % self.name, dtype=np.int32, dimensions=(self._par_dim,), diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index dca94c8f40..9a0608454a 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -689,7 +689,6 @@ def test_msf_interpolate(): eqn_inject = sf.inject(field=u, expr=sf) op2 = Operator(eqn_inject) - op2(time_m=0, time_M=4) # There should be 4 points touched for each source point From c4c839f6e328850a09fbf2ec084423cc439c8925 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 1 Aug 2023 12:39:43 +0000 Subject: [PATCH 2/5] compiler: Tweak pow_to_mul & factorize --- devito/passes/clusters/factorization.py | 17 +++++++------ devito/symbolics/queries.py | 9 ++++--- tests/test_dse.py | 32 +++++++++++++++++++++---- 3 files changed, 43 insertions(+), 15 deletions(-) diff --git a/devito/passes/clusters/factorization.py b/devito/passes/clusters/factorization.py index f023762402..1353bf89b3 100644 --- a/devito/passes/clusters/factorization.py +++ b/devito/passes/clusters/factorization.py @@ -142,13 +142,16 @@ def run(expr): terms.append(i) # Collect common funcs - w_funcs = Add(*w_funcs, evaluate=False) - w_funcs = collect(w_funcs, funcs, evaluate=False) - try: - terms.extend([Mul(k, collect_const(v), evaluate=False) - for k, v in w_funcs.items()]) - except AttributeError: - assert w_funcs == 0 + if len(w_funcs) > 1: + w_funcs = Add(*w_funcs, evaluate=False) + w_funcs = collect(w_funcs, funcs, evaluate=False) + try: + terms.extend([Mul(k, collect_const(v), evaluate=False) + for k, v in w_funcs.items()]) + except AttributeError: + assert w_funcs == 0 + else: + terms.extend(w_funcs) # Collect common pows w_pows = Add(*w_pows, evaluate=False) diff --git a/devito/symbolics/queries.py b/devito/symbolics/queries.py index ec86ae7809..c4002508cb 100644 --- a/devito/symbolics/queries.py +++ b/devito/symbolics/queries.py @@ -20,7 +20,8 @@ # * Number # * Symbol # * Indexed -extra_leaves = (FieldFromPointer, FieldFromComposite, IndexedBase, AbstractObject) +extra_leaves = (FieldFromPointer, FieldFromComposite, IndexedBase, AbstractObject, + IndexedPointer) def q_symbol(expr): @@ -31,7 +32,9 @@ def q_symbol(expr): def q_leaf(expr): - return expr.is_Atom or expr.is_Indexed or isinstance(expr, extra_leaves) + return (expr.is_Atom or + expr.is_Indexed or + isinstance(expr, extra_leaves)) def q_indexed(expr): @@ -51,7 +54,7 @@ def q_derivative(expr): def q_terminal(expr): return (expr.is_Symbol or expr.is_Indexed or - isinstance(expr, extra_leaves + (IndexedPointer,))) + isinstance(expr, extra_leaves)) def q_routine(expr): diff --git a/tests/test_dse.py b/tests/test_dse.py index 730021c3d8..2aefe69ed4 100644 --- a/tests/test_dse.py +++ b/tests/test_dse.py @@ -2,21 +2,27 @@ import pytest from cached_property import cached_property +from sympy import Mul # noqa + from conftest import (skipif, EVAL, _R, assert_structure, assert_blocking, # noqa get_params, get_arrays, check_array) -from devito import (NODE, Eq, Inc, Constant, Function, TimeFunction, SparseTimeFunction, # noqa - Dimension, SubDimension, ConditionalDimension, DefaultDimension, Grid, - Operator, norm, grad, div, dimensions, switchconfig, configuration, - centered, first_derivative, solve, transpose, Abs, cos, sin, sqrt) +from devito import (NODE, Eq, Inc, Constant, Function, TimeFunction, # noqa + SparseTimeFunction, Dimension, SubDimension, + ConditionalDimension, DefaultDimension, Grid, Operator, + norm, grad, div, dimensions, switchconfig, configuration, + centered, first_derivative, solve, transpose, Abs, cos, + sin, sqrt) from devito.exceptions import InvalidArgument, InvalidOperator from devito.finite_differences.differentiable import diffify from devito.ir import (Conditional, DummyEq, Expression, Iteration, FindNodes, FindSymbols, ParallelIteration, retrieve_iteration_tree) from devito.passes.clusters.aliases import collect +from devito.passes.clusters.factorization import collect_nested from devito.passes.clusters.cse import Temp, _cse from devito.passes.iet.parpragma import VExpanded from devito.symbolics import (INT, FLOAT, DefFunction, FieldFromPointer, # noqa - Keyword, SizeOf, estimate_cost, pow_to_mul, indexify) + IndexedPointer, Keyword, SizeOf, estimate_cost, + pow_to_mul, indexify) from devito.tools import as_tuple, generator from devito.types import Array, Scalar, Symbol @@ -161,6 +167,9 @@ def test_cse(exprs, expected, min_cost): ('fa[x]**(-s)', 'fa[x]**(-s)'), ('-2/(s**2)', '-2/(s*s)'), ('-fa[x]', '-fa[x]'), + ('Mul(SizeOf("char"), ' + '-IndexedPointer(FieldFromPointer("size", fa._C_symbol), x), evaluate=False)', + 'sizeof(char)*(-fa_vec->size[x])'), ]) def test_pow_to_mul(expr, expected): grid = Grid((4, 5)) @@ -173,6 +182,19 @@ def test_pow_to_mul(expr, expected): assert str(pow_to_mul(eval(expr))) == expected +@pytest.mark.parametrize('expr,expected', [ + ('s - SizeOf("int")*fa[x]', 's - fa[x]*sizeof(int)'), +]) +def test_factorize(expr, expected): + grid = Grid((4, 5)) + x, y = grid.dimensions + + s = Scalar(name='s') # noqa + fa = Function(name='fa', grid=grid, dimensions=(x,), shape=(4,)) # noqa + + assert str(collect_nested(eval(expr))) == expected + + @pytest.mark.parametrize('expr,expected,estimate', [ ('Eq(t0, 3)', 0, False), ('Eq(t0, 4.5)', 0, False), From eafda08dc304e85b47d9b65b994c7136615d3d63 Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 7 Sep 2023 08:21:13 -0400 Subject: [PATCH 3/5] api: reconstruct sparse with subfunc rather than its data --- devito/types/sparse.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/devito/types/sparse.py b/devito/types/sparse.py index 4244c001ab..f036a68c9c 100644 --- a/devito/types/sparse.py +++ b/devito/types/sparse.py @@ -791,7 +791,7 @@ class SparseFunction(AbstractSparseFunction): _sub_functions = ('coordinates',) - __rkwargs__ = AbstractSparseFunction.__rkwargs__ + ('coordinates_data',) + __rkwargs__ = AbstractSparseFunction.__rkwargs__ + ('coordinates',) def __init_finalize__(self, *args, **kwargs): super().__init_finalize__(*args, **kwargs) @@ -1014,8 +1014,8 @@ class PrecomputedSparseFunction(AbstractSparseFunction): _sub_functions = ('gridpoints', 'coordinates', 'interpolation_coeffs') __rkwargs__ = (AbstractSparseFunction.__rkwargs__ + - ('r', 'gridpoints_data', 'coordinates_data', - 'interpolation_coeffs_data')) + ('r', 'gridpoints', 'coordinates', + 'interpolation_coeffs')) def __init_finalize__(self, *args, **kwargs): super().__init_finalize__(*args, **kwargs) From cd59050c325388d7343dfc2fcea95be21481fbcf Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 7 Sep 2023 08:21:48 -0400 Subject: [PATCH 4/5] compiler: remove mul print tweak --- devito/symbolics/printer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 8f7ef6a719..c47ef95bfc 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -105,10 +105,6 @@ def _print_Mod(self, expr): args = ['(%s)' % self._print(a) for a in expr.args] return '%'.join(args) - def _print_Mul(self, expr): - term = super()._print_Mul(expr) - return term.replace("(-1)*", "-") - def _print_Min(self, expr): if has_integer_args(*expr.args) and len(expr.args) == 2: return "MIN(%s)" % self._print(expr.args)[1:-1] From 68c29d47a829cbc751a6b54c08ef2d664e92eb5a Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 7 Sep 2023 09:53:00 -0400 Subject: [PATCH 5/5] api: fix subfunction handling (subs/rebuild/...) --- devito/symbolics/printer.py | 4 + devito/types/dense.py | 3 +- devito/types/sparse.py | 113 ++++++++---------- examples/seismic/inversion/inversion_utils.py | 22 ++-- tests/test_interpolation.py | 4 +- tests/test_pickle.py | 2 +- tests/{test_msparse.py => test_sparse.py} | 64 +++++++++- 7 files changed, 128 insertions(+), 84 deletions(-) rename tests/{test_msparse.py => test_sparse.py} (84%) diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index c47ef95bfc..8f7ef6a719 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -105,6 +105,10 @@ def _print_Mod(self, expr): args = ['(%s)' % self._print(a) for a in expr.args] return '%'.join(args) + def _print_Mul(self, expr): + term = super()._print_Mul(expr) + return term.replace("(-1)*", "-") + def _print_Min(self, expr): if has_integer_args(*expr.args) and len(expr.args) == 2: return "MIN(%s)" % self._print(expr.args)[1:-1] diff --git a/devito/types/dense.py b/devito/types/dense.py index 4c912d2704..ec371b662c 100644 --- a/devito/types/dense.py +++ b/devito/types/dense.py @@ -20,7 +20,7 @@ from devito.finite_differences import Differentiable, generate_fd_shortcuts from devito.tools import (ReducerMap, as_tuple, c_restrict_void_p, flatten, is_integer, memoized_meth, dtype_to_ctype, humanbytes) -from devito.types.dimension import Dimension, DynamicDimension +from devito.types.dimension import Dimension from devito.types.args import ArgProvider from devito.types.caching import CacheManager from devito.types.basic import AbstractFunction, Size @@ -1040,6 +1040,7 @@ def __indices_setup__(cls, *args, **kwargs): dimensions = grid.dimensions if args: + assert len(args) == len(dimensions) return tuple(dimensions), tuple(args) # Staggered indices diff --git a/devito/types/sparse.py b/devito/types/sparse.py index f036a68c9c..42ecdc4a9a 100644 --- a/devito/types/sparse.py +++ b/devito/types/sparse.py @@ -61,11 +61,9 @@ def __indices_setup__(cls, *args, **kwargs): dimensions = (Dimension(name='p_%s' % kwargs["name"]),) if args: - indices = args + return tuple(dimensions), tuple(args) else: - indices = dimensions - - return dimensions, indices + return dimensions, dimensions @classmethod def __shape_setup__(cls, **kwargs): @@ -80,16 +78,6 @@ def __shape_setup__(cls, **kwargs): shape = (glb_npoint[grid.distributor.myrank],) return shape - def func(self, *args, **kwargs): - # Rebuild subfunctions first to avoid new data creation as we have to use `_data` - # as a reconstruction kwargs to avoid the circular dependency - # with the parent in SubFunction - # This is also necessary to avoid shape issue in the SubFunction with mpi - for s in self._sub_functions: - if getattr(self, s) is not None: - kwargs.update({s: getattr(self, s).func(*args, **kwargs)}) - return super().func(*args, **kwargs) - def __fd_setup__(self): """ Dynamically add derivative short-cuts. @@ -108,24 +96,39 @@ def __distributor_setup__(self, **kwargs): ) def __subfunc_setup__(self, key, suffix, dtype=None): + # Shape and dimensions from args + name = '%s_%s' % (self.name, suffix) + + if key is not None and not isinstance(key, SubFunction): + key = np.array(key) + + if key is not None: + dimensions = (self._sparse_dim, Dimension(name='d')) + if key.ndim > 2: + dimensions = (self._sparse_dim, Dimension(name='d'), + *mkdims("i", n=key.ndim-2)) + else: + dimensions = (self._sparse_dim, Dimension(name='d')) + shape = (self.npoint, self.grid.dim, *key.shape[2:]) + else: + dimensions = (self._sparse_dim, Dimension(name='d')) + shape = (self.npoint, self.grid.dim) + + # Check if already a SubFunction if isinstance(key, SubFunction): - return key + # Need to rebuild so the dimensions match the parent SparseFunction + indices = (self.indices[self._sparse_position], *key.indices[1:]) + return key._rebuild(*indices, name=name, shape=shape, + alias=self.alias, halo=None) elif key is not None and not isinstance(key, Iterable): raise ValueError("`%s` must be either SubFunction " "or iterable (e.g., list, np.ndarray)" % key) - name = '%s_%s' % (self.name, suffix) - dimensions = (self._sparse_dim, Dimension(name='d')) - shape = (self.npoint, self.grid.dim) - if key is None: # Fallback to default behaviour dtype = dtype or self.dtype else: - if key is not None: - key = np.array(key) - - if (shape != key.shape[:2] and key.shape != (shape[1],)) and \ + if (shape != key.shape and key.shape != (shape[1],)) and \ self._distributor.nprocs == 1: raise ValueError("Incompatible shape for %s, `%s`; expected `%s`" % (suffix, key.shape[:2], shape)) @@ -136,12 +139,8 @@ def __subfunc_setup__(self, key, suffix, dtype=None): else: dtype = dtype or self.dtype - if key is not None and key.ndim > 2: - shape = (*shape, *key.shape[2:]) - dimensions = (*dimensions, *mkdims("i", n=key.ndim-2)) - sf = SubFunction( - name=name, parent=self, dtype=dtype, dimensions=dimensions, + name=name, dtype=dtype, dimensions=dimensions, shape=shape, space_order=0, initializer=key, alias=self.alias, distributor=self._distributor ) @@ -657,20 +656,6 @@ def time_dim(self): """The time Dimension.""" return self._time_dim - @classmethod - def __indices_setup__(cls, *args, **kwargs): - dimensions = as_tuple(kwargs.get('dimensions')) - if not dimensions: - dimensions = (kwargs['grid'].time_dim, - Dimension(name='p_%s' % kwargs["name"])) - - if args: - indices = args - else: - indices = dimensions - - return dimensions, indices - @classmethod def __shape_setup__(cls, **kwargs): shape = kwargs.get('shape') @@ -686,6 +671,18 @@ def __shape_setup__(cls, **kwargs): return tuple(shape) + @classmethod + def __indices_setup__(cls, *args, **kwargs): + dimensions = as_tuple(kwargs.get('dimensions')) + if not dimensions: + dimensions = (kwargs['grid'].time_dim, + Dimension(name='p_%s' % kwargs["name"]),) + + if args: + return tuple(dimensions), tuple(args) + else: + return dimensions, dimensions + @property def nt(self): return self.shape[self._time_position] @@ -1032,13 +1029,14 @@ def __init_finalize__(self, *args, **kwargs): if r <= 0: raise ValueError('`r` must be > 0') # Make sure radius matches the coefficients size - nr = interpolation_coeffs.shape[-1] - if nr // 2 != r: - if nr == r: - r = r // 2 - else: - raise ValueError("Interpolation coefficients shape %d do " - "not match specified radius %d" % (r, nr)) + if interpolation_coeffs is not None: + nr = interpolation_coeffs.shape[-1] + if nr // 2 != r: + if nr == r: + r = r // 2 + else: + raise ValueError("Interpolation coefficients shape %d do " + "not match specified radius %d" % (r, nr)) self._radius = r if coordinates is not None and gridpoints is not None: @@ -1680,23 +1678,6 @@ def inject(self, field, expr, u_t=None, p_t=None): return out - @classmethod - def __indices_setup__(cls, *args, **kwargs): - """ - Return the default Dimension indices for a given data shape. - """ - dimensions = kwargs.get('dimensions') - if dimensions is None: - dimensions = (kwargs['grid'].time_dim, Dimension( - name='p_%s' % kwargs["name"])) - - if args: - indices = args - else: - indices = dimensions - - return dimensions, indices - @classmethod def __shape_setup__(cls, **kwargs): # This happens before __init__, so we have to get 'npoint' diff --git a/examples/seismic/inversion/inversion_utils.py b/examples/seismic/inversion/inversion_utils.py index 7f4784ae8c..9f5709b315 100644 --- a/examples/seismic/inversion/inversion_utils.py +++ b/examples/seismic/inversion/inversion_utils.py @@ -7,19 +7,15 @@ def compute_residual(res, dobs, dsyn): """ Computes the data residual dsyn - dobs into residual """ - if res.grid.distributor.is_parallel: - # If we run with MPI, we have to compute the residual via an operator - # First make sure we can take the difference and that receivers are at the - # same position - assert np.allclose(dobs.coordinates.data[:], dsyn.coordinates.data) - assert np.allclose(res.coordinates.data[:], dsyn.coordinates.data) - # Create a difference operator - diff_eq = Eq(res, dsyn.subs({dsyn.dimensions[-1]: res.dimensions[-1]}) - - dobs.subs({dobs.dimensions[-1]: res.dimensions[-1]})) - Operator(diff_eq)() - else: - # A simple data difference is enough in serial - res.data[:] = dsyn.data[:] - dobs.data[:] + # If we run with MPI, we have to compute the residual via an operator + # First make sure we can take the difference and that receivers are at the + # same position + assert np.allclose(dobs.coordinates.data[:], dsyn.coordinates.data) + assert np.allclose(res.coordinates.data[:], dsyn.coordinates.data) + # Create a difference operator + diff_eq = Eq(res, dsyn.subs({dsyn.dimensions[-1]: res.dimensions[-1]}) - + dobs.subs({dobs.dimensions[-1]: res.dimensions[-1]})) + Operator(diff_eq)() return res diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index 9a0608454a..3a22ca1db7 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -242,8 +242,8 @@ def test_precomputed_injection_time(r): sf = PrecomputedSparseTimeFunction(name='s', grid=m.grid, r=r, npoint=len(coords), gridpoints=gridpoints, nt=nt, interpolation_coeffs=interpolation_coeffs) - - expr = sf.inject(m, Float(1.)) + sf.data.fill(1.) + expr = sf.inject(m, sf) op = Operator(expr) diff --git a/tests/test_pickle.py b/tests/test_pickle.py index 62423b2c15..16f44bdaed 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -111,7 +111,7 @@ def test_precomputed_sparse_function(self, mode, pickle): sf = PrecomputedSparseTimeFunction( name='sf', grid=grid, r=2, npoint=3, nt=5, - interpolation_coeffs=np.ndarray(shape=(3, 2, 2)), **kw + interpolation_coeffs=np.random.randn(3, 2, 2), **kw ) sf.data[2, 1] = 5. diff --git a/tests/test_msparse.py b/tests/test_sparse.py similarity index 84% rename from tests/test_msparse.py rename to tests/test_sparse.py index 5cbfde848a..04545d7cc6 100644 --- a/tests/test_msparse.py +++ b/tests/test_sparse.py @@ -4,7 +4,13 @@ import numpy as np import scipy.sparse -from devito import Grid, TimeFunction, Eq, Operator, MatrixSparseTimeFunction +from devito import Grid, TimeFunction, Eq, Operator, Dimension +from devito import (SparseFunction, SparseTimeFunction, PrecomputedSparseFunction, + PrecomputedSparseTimeFunction, MatrixSparseTimeFunction) + + +_sptypes = [SparseFunction, SparseTimeFunction, + PrecomputedSparseFunction, PrecomputedSparseTimeFunction] class TestMatrixSparseTimeFunction(object): @@ -394,5 +400,61 @@ def test_mpi(self): assert sf.data[0, 0] == -3.0 # 1 * (1 * 1) * 1 + (-1) * (2 * 2) * 1 +class TestSparseFunction(object): + + @pytest.mark.parametrize('sptype', _sptypes) + def test_rebuild(self, sptype): + grid = Grid((3, 3, 3)) + # Base object + sp = sptype(name="s", grid=grid, npoint=1, nt=11, r=2, + interpolation_coeffs=np.random.randn(1, 3, 2), + coordinates=np.random.randn(1, 3)) + + # Check subfunction setup + for subf in sp._sub_functions: + if getattr(sp, subf) is not None: + assert getattr(sp, subf).name.startswith("s_") + + # Rebuild with different name, this should drop the function + # and create new data + sp2 = sp._rebuild(name="sr") + + # Check new subfunction + for subf in sp2._sub_functions: + if getattr(sp2, subf) is not None: + assert getattr(sp2, subf).name.startswith("sr_") + assert np.all(getattr(sp2, subf).data == 0) + + # Rebuild with different name as an alias + sp2 = sp._rebuild(name="sr2", alias=True) + for subf in sp2._sub_functions: + if getattr(sp2, subf) is not None: + assert getattr(sp2, subf).name.startswith("sr2_") + assert getattr(sp2, subf).data is None + + @pytest.mark.parametrize('sptype', _sptypes) + def test_subs(self, sptype): + grid = Grid((3, 3, 3)) + # Base object + sp = sptype(name="s", grid=grid, npoint=1, nt=11, r=2, + interpolation_coeffs=np.random.randn(1, 3, 2), + coordinates=np.random.randn(1, 3)) + + # Check subfunction setup + for subf in sp._sub_functions: + if getattr(sp, subf) is not None: + assert getattr(sp, subf).dimensions[0] == sp._sparse_dim + + # Do substitution on sparse dimension + new_spdim = Dimension(name="newsp") + + sps = sp._subs(sp._sparse_dim, new_spdim) + assert sps.indices[sp._sparse_position] == new_spdim + for subf in sps._sub_functions: + if getattr(sps, subf) is not None: + assert getattr(sps, subf).indices[0] == new_spdim + assert np.all(getattr(sps, subf).data == getattr(sp, subf).data) + + if __name__ == "__main__": TestMatrixSparseTimeFunction().test_mpi()