Skip to content

Commit

Permalink
tests: add non subfunc test
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Oct 25, 2024
1 parent a6172d5 commit c3da175
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 5 deletions.
9 changes: 6 additions & 3 deletions devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,10 @@ def _dist_scatter(self, alias=None, data=None):
mapper = {self: self._dist_data_scatter(data=data)}
for i in self._sub_functions:
if getattr(alias, i) is not None:
mapper.update(self._dist_subfunc_scatter(getattr(self, i)))
# Pick up alias' in case runtime SparseFunctions is missing
# a subfunction
sf = getattr(self, i) or getattr(alias, i)
mapper.update(self._dist_subfunc_scatter(sf))
return mapper

def _eval_at(self, func):
Expand Down Expand Up @@ -662,7 +665,7 @@ def _arg_values(self, **kwargs):
else:
# We've been provided a pure-data replacement (array)
values = {}
for k, v in self._dist_scatter(new).items():
for k, v in self._dist_scatter(data=new).items():
values[k.name] = v
for i, s in zip(k.indices, v.shape):
size = s - sum(k._size_nodomain[i])
Expand Down Expand Up @@ -859,7 +862,7 @@ def __init_finalize__(self, *args, **kwargs):
super().__init_finalize__(*args, **kwargs)

# Set up sparse point coordinates
keys = ('coordinates','coordinates_data')
keys = ('coordinates', 'coordinates_data')
self._coordinates = self.__subfunc_setup__('coords', keys, **kwargs)
self._dist_origin = {self._coordinates: self.grid.origin_offset}

Expand Down
20 changes: 18 additions & 2 deletions tests/test_rebuild.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import numpy as np
import pytest

from devito import Dimension, Function
from devito.types import StencilDimension
from devito import Dimension, Function, Grid
from devito.types import StencilDimension, SparseFunction, PrecomputedSparseFunction
from devito.data.allocators import DataReference


Expand Down Expand Up @@ -65,3 +65,19 @@ def test_stencil_dimension_borked(self):

# TODO: Look into Symbol._cache_key and the way the key is generated
assert sd0 is sd1


class TestSparseFunction:

@pytest.mark.parametrize('sfunc', [SparseFunction, PrecomputedSparseFunction])
def test_none_subfunc(self, sfunc):
grid = Grid((4, 4))
coords = np.zeros((5, 2))

s = sfunc(name='s', grid=grid, npoint=5, coordinates=coords, r=1)

assert s.coordinates is not None

# Explicity set coordinates to None
sr = s._rebuild(function=None, initializer=None, coordinates=None)
assert sr.coordinates is None

0 comments on commit c3da175

Please sign in to comment.