diff --git a/devito/types/sparse.py b/devito/types/sparse.py index f7f40fcaea..dc6c225695 100644 --- a/devito/types/sparse.py +++ b/devito/types/sparse.py @@ -135,17 +135,21 @@ def __distributor_setup__(self, **kwargs): def __subfunc_setup__(self, suffix, keys, dtype=None, inkwargs=False, **kwargs): key = None for k in keys: - if k in kwargs: - if kwargs[k] is not None: - key = kwargs[k] - break - else: - # In cases such as rebuild, - # the subfunction may be passed explicitly as None - return None + if k not in kwargs: + continue + elif kwargs[k] is None: + # In cases such as rebuild, + # the subfunction may be passed explicitly as None + return None + else: + key = kwargs[k] + break else: if inkwargs: - # Only create the subfunction if provided. + # Only create the subfunction if provided. This is useful + # with PrecomputedSparseFunctions that can have different subfunctions + # to skip creating extra if another one has already + # been provided return None # Shape and dimensions from args @@ -619,10 +623,14 @@ def _dist_subfunc_gather(self, sfuncd, subfunc): # `_dist_scatter` is here sent. def _dist_scatter(self, alias=None, data=None): + key = alias or self mapper = {self: self._dist_data_scatter(data=data)} for i in self._sub_functions: - if getattr(alias, i) is not None: - mapper.update(self._dist_subfunc_scatter(getattr(self, i))) + if getattr(key, i) is not None: + # Pick up alias' in case runtime SparseFunctions is missing + # a subfunction + sf = getattr(self, i) or getattr(key, i) + mapper.update(self._dist_subfunc_scatter(sf)) return mapper def _eval_at(self, func): @@ -662,7 +670,7 @@ def _arg_values(self, **kwargs): else: # We've been provided a pure-data replacement (array) values = {} - for k, v in self._dist_scatter(new).items(): + for k, v in self._dist_scatter(data=new).items(): values[k.name] = v for i, s in zip(k.indices, v.shape): size = s - sum(k._size_nodomain[i]) @@ -859,7 +867,7 @@ def __init_finalize__(self, *args, **kwargs): super().__init_finalize__(*args, **kwargs) # Set up sparse point coordinates - keys = ('coordinates','coordinates_data') + keys = ('coordinates', 'coordinates_data') self._coordinates = self.__subfunc_setup__('coords', keys, **kwargs) self._dist_origin = {self._coordinates: self.grid.origin_offset} @@ -1113,7 +1121,6 @@ def __init_finalize__(self, *args, **kwargs): if not any(k in kwargs for k in ('coordinates', 'gridpoints', 'coordinates_data', 'gridpoints_data')): - print(kwargs) raise ValueError("PrecomputedSparseFunction requires `coordinates`" "or `gridpoints` arguments") @@ -1135,9 +1142,9 @@ def __init_finalize__(self, *args, **kwargs): self._dist_origin.update({self._gridpoints: self.grid.origin_ioffset}) # Setup the interpolation coefficients. These are compulsory - keys = ('interpolation_coeffs', 'interpolation_coeffs_data') + ckeys = ('interpolation_coeffs', 'interpolation_coeffs_data') self._interpolation_coeffs = \ - self.__subfunc_setup__('interp_coeffs', keys, dtype=dtype, **kwargs) + self.__subfunc_setup__('interp_coeffs', ckeys, dtype=dtype, **kwargs) # Grid points per sparse point (2 in the case of bilinear and trilinear) r = kwargs.get('r') @@ -1146,7 +1153,7 @@ def __init_finalize__(self, *args, **kwargs): if r <= 0: raise ValueError('`r` must be > 0') # Make sure radius matches the coefficients size - if self._interpolation_coeffs is not None: + if any(c in kwargs for c in ckeys): nr = self._interpolation_coeffs.shape[-1] if nr // 2 != r: if nr == r: @@ -2148,7 +2155,7 @@ def manual_scatter(self, *, data_all_zero=False): **self._build_par_dim_to_nnz(scattered_gp, active_mrow), } - def _dist_scatter(self, data=None): + def _dist_scatter(self, alias=None, data=None): assert data is None if self.scatter_result is None: raise Exception("_dist_scatter called before manual_scatter called") diff --git a/tests/test_rebuild.py b/tests/test_rebuild.py index 5498d819a7..79b2765a30 100644 --- a/tests/test_rebuild.py +++ b/tests/test_rebuild.py @@ -1,8 +1,8 @@ import numpy as np import pytest -from devito import Dimension, Function -from devito.types import StencilDimension +from devito import Dimension, Function, Grid +from devito.types import StencilDimension, SparseFunction, PrecomputedSparseFunction from devito.data.allocators import DataReference @@ -65,3 +65,19 @@ def test_stencil_dimension_borked(self): # TODO: Look into Symbol._cache_key and the way the key is generated assert sd0 is sd1 + + +class TestSparseFunction: + + @pytest.mark.parametrize('sfunc', [SparseFunction, PrecomputedSparseFunction]) + def test_none_subfunc(self, sfunc): + grid = Grid((4, 4)) + coords = np.zeros((5, 2)) + + s = sfunc(name='s', grid=grid, npoint=5, coordinates=coords, r=1) + + assert s.coordinates is not None + + # Explicity set coordinates to None + sr = s._rebuild(function=None, initializer=None, coordinates=None) + assert sr.coordinates is None