From f3bfe47728dff9910b5cb5636b0c4ff04665901e Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 12 Oct 2023 11:19:53 -0400 Subject: [PATCH] api: enforce interpolation radius to be smaller than any input space order --- devito/operations/interpolators.py | 17 ++++++++++++++++- tests/test_interpolation.py | 13 +++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/devito/operations/interpolators.py b/devito/operations/interpolators.py index 3d2dcb74660..10d5ce4fbc9 100644 --- a/devito/operations/interpolators.py +++ b/devito/operations/interpolators.py @@ -1,11 +1,12 @@ from abc import ABC, abstractmethod +from functools import wraps import sympy from cached_property import cached_property from devito.finite_differences.differentiable import Mul from devito.finite_differences.elementary import floor -from devito.symbolics import retrieve_function_carriers, INT +from devito.symbolics import retrieve_function_carriers, retrieve_functions, INT from devito.tools import as_tuple, flatten from devito.types import (ConditionalDimension, Eq, Inc, Evaluable, Symbol, CustomDimension) @@ -14,6 +15,18 @@ __all__ = ['LinearInterpolator', 'PrecomputedInterpolator'] +def check_radius(func): + @wraps(func) + def wrapper(interp, *args, **kwargs): + r = interp.sfunction.r + funcs = set(retrieve_functions(args)) - {interp.sfunction} + so = min({f.space_order for f in funcs}) + if so < r: + raise ValueError("Space order %d smaller than interpolation r %d" % (so, r)) + return func(interp, *args, **kwargs) + return wrapper + + class UnevaluatedSparseOperation(sympy.Expr, Evaluable): """ @@ -209,6 +222,7 @@ def _interp_idx(self, variables, implicit_dims=None): return idx_subs, temps + @check_radius def interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None): """ Generate equations interpolating an arbitrary expression into ``self``. @@ -226,6 +240,7 @@ def interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None): """ return Interpolation(expr, increment, implicit_dims, self_subs, self) + @check_radius def inject(self, field, expr, implicit_dims=None): """ Generate equations injecting an arbitrary expression into a field. diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index 97d86c1759f..7d07c39c781 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -761,3 +761,16 @@ def test_inject_function(): for i in [0, 1, 3, 4]: for j in [0, 1, 3, 4]: assert u.data[1, i, j] == 0 + + +def test_interpolation_radius(): + nt = 11 + + grid = Grid(shape=(5, 5)) + u = TimeFunction(name="u", grid=grid, space_order=0) + src = SparseTimeFunction(name="src", grid=grid, nt=nt, npoint=1) + try: + src.interpolate(u) + assert False + except ValueError: + assert True