Skip to content

Commit

Permalink
tests: add non subfunc test
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Oct 25, 2024
1 parent a6172d5 commit 898dd73
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 20 deletions.
43 changes: 25 additions & 18 deletions devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,21 @@ def __distributor_setup__(self, **kwargs):
def __subfunc_setup__(self, suffix, keys, dtype=None, inkwargs=False, **kwargs):
key = None
for k in keys:
if k in kwargs:
if kwargs[k] is not None:
key = kwargs[k]
break
else:
# In cases such as rebuild,
# the subfunction may be passed explicitly as None
return None
if k not in kwargs:
continue
elif kwargs[k] is None:
# In cases such as rebuild,
# the subfunction may be passed explicitly as None
return None
else:
key = kwargs[k]
break
else:
if inkwargs:
# Only create the subfunction if provided.
# Only create the subfunction if provided. This is useful
# with PrecomputedSparseFunctions that can have different subfunctions
# to skip creating extra if another one has already
# been provided
return None

# Shape and dimensions from args
Expand Down Expand Up @@ -619,10 +623,14 @@ def _dist_subfunc_gather(self, sfuncd, subfunc):
# `_dist_scatter` is here sent.

def _dist_scatter(self, alias=None, data=None):
key = alias or self
mapper = {self: self._dist_data_scatter(data=data)}
for i in self._sub_functions:
if getattr(alias, i) is not None:
mapper.update(self._dist_subfunc_scatter(getattr(self, i)))
if getattr(key, i) is not None:
# Pick up alias' in case runtime SparseFunctions is missing
# a subfunction
sf = getattr(self, i) or getattr(key, i)
mapper.update(self._dist_subfunc_scatter(sf))
return mapper

def _eval_at(self, func):
Expand Down Expand Up @@ -662,7 +670,7 @@ def _arg_values(self, **kwargs):
else:
# We've been provided a pure-data replacement (array)
values = {}
for k, v in self._dist_scatter(new).items():
for k, v in self._dist_scatter(data=new).items():
values[k.name] = v
for i, s in zip(k.indices, v.shape):
size = s - sum(k._size_nodomain[i])
Expand Down Expand Up @@ -859,7 +867,7 @@ def __init_finalize__(self, *args, **kwargs):
super().__init_finalize__(*args, **kwargs)

# Set up sparse point coordinates
keys = ('coordinates','coordinates_data')
keys = ('coordinates', 'coordinates_data')
self._coordinates = self.__subfunc_setup__('coords', keys, **kwargs)
self._dist_origin = {self._coordinates: self.grid.origin_offset}

Expand Down Expand Up @@ -1113,7 +1121,6 @@ def __init_finalize__(self, *args, **kwargs):

if not any(k in kwargs for k in ('coordinates', 'gridpoints',
'coordinates_data', 'gridpoints_data')):
print(kwargs)
raise ValueError("PrecomputedSparseFunction requires `coordinates`"
"or `gridpoints` arguments")

Expand All @@ -1135,9 +1142,9 @@ def __init_finalize__(self, *args, **kwargs):
self._dist_origin.update({self._gridpoints: self.grid.origin_ioffset})

# Setup the interpolation coefficients. These are compulsory
keys = ('interpolation_coeffs', 'interpolation_coeffs_data')
ckeys = ('interpolation_coeffs', 'interpolation_coeffs_data')
self._interpolation_coeffs = \
self.__subfunc_setup__('interp_coeffs', keys, dtype=dtype, **kwargs)
self.__subfunc_setup__('interp_coeffs', ckeys, dtype=dtype, **kwargs)

# Grid points per sparse point (2 in the case of bilinear and trilinear)
r = kwargs.get('r')
Expand All @@ -1146,7 +1153,7 @@ def __init_finalize__(self, *args, **kwargs):
if r <= 0:
raise ValueError('`r` must be > 0')
# Make sure radius matches the coefficients size
if self._interpolation_coeffs is not None:
if any(c in kwargs for c in ckeys):
nr = self._interpolation_coeffs.shape[-1]
if nr // 2 != r:
if nr == r:
Expand Down Expand Up @@ -2148,7 +2155,7 @@ def manual_scatter(self, *, data_all_zero=False):
**self._build_par_dim_to_nnz(scattered_gp, active_mrow),
}

def _dist_scatter(self, data=None):
def _dist_scatter(self, alias=None, data=None):
assert data is None
if self.scatter_result is None:
raise Exception("_dist_scatter called before manual_scatter called")
Expand Down
20 changes: 18 additions & 2 deletions tests/test_rebuild.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import numpy as np
import pytest

from devito import Dimension, Function
from devito.types import StencilDimension
from devito import Dimension, Function, Grid
from devito.types import StencilDimension, SparseFunction, PrecomputedSparseFunction
from devito.data.allocators import DataReference


Expand Down Expand Up @@ -65,3 +65,19 @@ def test_stencil_dimension_borked(self):

# TODO: Look into Symbol._cache_key and the way the key is generated
assert sd0 is sd1


class TestSparseFunction:

@pytest.mark.parametrize('sfunc', [SparseFunction, PrecomputedSparseFunction])
def test_none_subfunc(self, sfunc):
grid = Grid((4, 4))
coords = np.zeros((5, 2))

s = sfunc(name='s', grid=grid, npoint=5, coordinates=coords, r=1)

assert s.coordinates is not None

# Explicity set coordinates to None
sr = s._rebuild(function=None, initializer=None, coordinates=None)
assert sr.coordinates is None

0 comments on commit 898dd73

Please sign in to comment.