Skip to content

Commit

Permalink
api: enforce interpolation radius to be smaller than any input space …
Browse files Browse the repository at this point in the history
…order
  • Loading branch information
mloubout committed Oct 12, 2023
1 parent 617ce82 commit f3bfe47
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
17 changes: 16 additions & 1 deletion devito/operations/interpolators.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from abc import ABC, abstractmethod
from functools import wraps

import sympy
from cached_property import cached_property

from devito.finite_differences.differentiable import Mul
from devito.finite_differences.elementary import floor
from devito.symbolics import retrieve_function_carriers, INT
from devito.symbolics import retrieve_function_carriers, retrieve_functions, INT
from devito.tools import as_tuple, flatten
from devito.types import (ConditionalDimension, Eq, Inc, Evaluable, Symbol,
CustomDimension)
Expand All @@ -14,6 +15,18 @@
__all__ = ['LinearInterpolator', 'PrecomputedInterpolator']


def check_radius(func):
@wraps(func)
def wrapper(interp, *args, **kwargs):
r = interp.sfunction.r
funcs = set(retrieve_functions(args)) - {interp.sfunction}
so = min({f.space_order for f in funcs})
if so < r:
raise ValueError("Space order %d smaller than interpolation r %d" % (so, r))
return func(interp, *args, **kwargs)
return wrapper


class UnevaluatedSparseOperation(sympy.Expr, Evaluable):

"""
Expand Down Expand Up @@ -209,6 +222,7 @@ def _interp_idx(self, variables, implicit_dims=None):

return idx_subs, temps

@check_radius
def interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None):
"""
Generate equations interpolating an arbitrary expression into ``self``.
Expand All @@ -226,6 +240,7 @@ def interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None):
"""
return Interpolation(expr, increment, implicit_dims, self_subs, self)

@check_radius
def inject(self, field, expr, implicit_dims=None):
"""
Generate equations injecting an arbitrary expression into a field.
Expand Down
13 changes: 13 additions & 0 deletions tests/test_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,3 +761,16 @@ def test_inject_function():
for i in [0, 1, 3, 4]:
for j in [0, 1, 3, 4]:
assert u.data[1, i, j] == 0


def test_interpolation_radius():
nt = 11

grid = Grid(shape=(5, 5))
u = TimeFunction(name="u", grid=grid, space_order=0)
src = SparseTimeFunction(name="src", grid=grid, nt=nt, npoint=1)
try:
src.interpolate(u)
assert False
except ValueError:
assert True

0 comments on commit f3bfe47

Please sign in to comment.