Skip to content

Commit

Permalink
tests: add non subfunc test
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Oct 25, 2024
1 parent a6172d5 commit 32f9b30
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 16 deletions.
35 changes: 21 additions & 14 deletions devito/types/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,21 @@ def __distributor_setup__(self, **kwargs):
def __subfunc_setup__(self, suffix, keys, dtype=None, inkwargs=False, **kwargs):
key = None
for k in keys:
if k in kwargs:
if kwargs[k] is not None:
key = kwargs[k]
break
else:
# In cases such as rebuild,
# the subfunction may be passed explicitly as None
return None
if k not in kwargs:
continue
elif kwargs[k] is None:
# In cases such as rebuild,
# the subfunction may be passed explicitly as None
return None
else:
key = kwargs[k]
break
else:
if inkwargs:
# Only create the subfunction if provided.
# Only create the subfunction if provided. This is useful
# with PrecomputedSparseFunctions that can have different subfunctions
# to skip creating extra if another one has already
# been provided
return None

# Shape and dimensions from args
Expand Down Expand Up @@ -619,10 +623,14 @@ def _dist_subfunc_gather(self, sfuncd, subfunc):
# `_dist_scatter` is here sent.

def _dist_scatter(self, alias=None, data=None):
key = alias or self
mapper = {self: self._dist_data_scatter(data=data)}
for i in self._sub_functions:
if getattr(alias, i) is not None:
mapper.update(self._dist_subfunc_scatter(getattr(self, i)))
if getattr(key, i) is not None:
# Pick up alias' in case runtime SparseFunctions is missing
# a subfunction
sf = getattr(self, i) or getattr(key, i)
mapper.update(self._dist_subfunc_scatter(sf))
return mapper

def _eval_at(self, func):
Expand Down Expand Up @@ -662,7 +670,7 @@ def _arg_values(self, **kwargs):
else:
# We've been provided a pure-data replacement (array)
values = {}
for k, v in self._dist_scatter(new).items():
for k, v in self._dist_scatter(data=new).items():
values[k.name] = v
for i, s in zip(k.indices, v.shape):
size = s - sum(k._size_nodomain[i])
Expand Down Expand Up @@ -859,7 +867,7 @@ def __init_finalize__(self, *args, **kwargs):
super().__init_finalize__(*args, **kwargs)

# Set up sparse point coordinates
keys = ('coordinates','coordinates_data')
keys = ('coordinates', 'coordinates_data')
self._coordinates = self.__subfunc_setup__('coords', keys, **kwargs)
self._dist_origin = {self._coordinates: self.grid.origin_offset}

Expand Down Expand Up @@ -1113,7 +1121,6 @@ def __init_finalize__(self, *args, **kwargs):

if not any(k in kwargs for k in ('coordinates', 'gridpoints',
'coordinates_data', 'gridpoints_data')):
print(kwargs)
raise ValueError("PrecomputedSparseFunction requires `coordinates`"
"or `gridpoints` arguments")

Expand Down
20 changes: 18 additions & 2 deletions tests/test_rebuild.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import numpy as np
import pytest

from devito import Dimension, Function
from devito.types import StencilDimension
from devito import Dimension, Function, Grid
from devito.types import StencilDimension, SparseFunction, PrecomputedSparseFunction
from devito.data.allocators import DataReference


Expand Down Expand Up @@ -65,3 +65,19 @@ def test_stencil_dimension_borked(self):

# TODO: Look into Symbol._cache_key and the way the key is generated
assert sd0 is sd1


class TestSparseFunction:

@pytest.mark.parametrize('sfunc', [SparseFunction, PrecomputedSparseFunction])
def test_none_subfunc(self, sfunc):
grid = Grid((4, 4))
coords = np.zeros((5, 2))

s = sfunc(name='s', grid=grid, npoint=5, coordinates=coords, r=1)

assert s.coordinates is not None

# Explicity set coordinates to None
sr = s._rebuild(function=None, initializer=None, coordinates=None)
assert sr.coordinates is None

0 comments on commit 32f9b30

Please sign in to comment.