From aabb58117a16cef55998596b7bad754079d2e1de Mon Sep 17 00:00:00 2001 From: mloubout Date: Thu, 12 Oct 2023 11:19:53 -0400 Subject: [PATCH] api: enforce interpolation radius to be smaller than any input space order --- devito/operations/interpolators.py | 17 ++++++++++++++++- devito/types/grid.py | 2 +- tests/test_dle.py | 2 +- tests/test_interpolation.py | 22 +++++++++++++++++----- tests/test_mpi.py | 6 +++--- tests/test_operator.py | 10 +++++----- 6 files changed, 43 insertions(+), 16 deletions(-) diff --git a/devito/operations/interpolators.py b/devito/operations/interpolators.py index 3d2dcb74660..75536a6550b 100644 --- a/devito/operations/interpolators.py +++ b/devito/operations/interpolators.py @@ -1,11 +1,12 @@ from abc import ABC, abstractmethod +from functools import wraps import sympy from cached_property import cached_property from devito.finite_differences.differentiable import Mul from devito.finite_differences.elementary import floor -from devito.symbolics import retrieve_function_carriers, INT +from devito.symbolics import retrieve_function_carriers, retrieve_functions, INT from devito.tools import as_tuple, flatten from devito.types import (ConditionalDimension, Eq, Inc, Evaluable, Symbol, CustomDimension) @@ -14,6 +15,18 @@ __all__ = ['LinearInterpolator', 'PrecomputedInterpolator'] +def check_radius(func): + @wraps(func) + def wrapper(interp, *args, **kwargs): + r = interp.sfunction.r + funcs = set(retrieve_functions(args)) - {interp.sfunction} + so = min({f.space_order for f in funcs} or {r}) + if so < r: + raise ValueError("Space order %d smaller than interpolation r %d" % (so, r)) + return func(interp, *args, **kwargs) + return wrapper + + class UnevaluatedSparseOperation(sympy.Expr, Evaluable): """ @@ -209,6 +222,7 @@ def _interp_idx(self, variables, implicit_dims=None): return idx_subs, temps + @check_radius def interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None): """ Generate equations interpolating an arbitrary expression into ``self``. @@ -226,6 +240,7 @@ def interpolate(self, expr, increment=False, self_subs={}, implicit_dims=None): """ return Interpolation(expr, increment, implicit_dims, self_subs, self) + @check_radius def inject(self, field, expr, implicit_dims=None): """ Generate equations injecting an arbitrary expression into a field. diff --git a/devito/types/grid.py b/devito/types/grid.py index de912844415..929134f9f66 100644 --- a/devito/types/grid.py +++ b/devito/types/grid.py @@ -261,7 +261,7 @@ def volume_cell(self): @property def spacing(self): """Spacing between grid points in m.""" - spacing = (np.array(self.extent) / (np.array(self.shape) - 1)).astype(self.dtype) + spacing = (np.array(self.extent) / (np.array(self.shape) - 1)) return as_tuple(spacing) @cached_property diff --git a/tests/test_dle.py b/tests/test_dle.py index 03c0b533c98..8e37947a361 100644 --- a/tests/test_dle.py +++ b/tests/test_dle.py @@ -709,7 +709,7 @@ def test_scheduling(self): """ grid = Grid(shape=(11, 11)) - u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=0) + u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=1) sf1 = SparseTimeFunction(name='s', grid=grid, npoint=1, nt=5) eqns = [Eq(u.forward, u + 1)] diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index 97d86c1759f..61a83acceb9 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -117,16 +117,15 @@ def test_precomputed_interpolation(r): origin = (0, 0) grid = Grid(shape=shape, origin=origin) - r = 2 # Constant for linear interpolation - # because we interpolate across 2 neighbouring points in each dimension def init(data): + # This is data with halo so need to shift to match the m.data expectations for i in range(data.shape[0]): for j in range(data.shape[1]): - data[i, j] = sin(grid.spacing[0]*i) + sin(grid.spacing[1]*j) + data[i, j] = sin(grid.spacing[0]*(i-r)) + sin(grid.spacing[1]*(j-r)) return data - m = Function(name='m', grid=grid, initializer=init, space_order=0) + m = Function(name='m', grid=grid, initializer=init, space_order=r) gridpoints, interpolation_coeffs = precompute_linear_interpolation(points, grid, origin, @@ -157,7 +156,7 @@ def test_precomputed_interpolation_time(r): r = 2 # Constant for linear interpolation # because we interpolate across 2 neighbouring points in each dimension - u = TimeFunction(name='u', grid=grid, space_order=0, save=5) + u = TimeFunction(name='u', grid=grid, space_order=r, save=5) for it in range(5): u.data[it, :] = it @@ -761,3 +760,16 @@ def test_inject_function(): for i in [0, 1, 3, 4]: for j in [0, 1, 3, 4]: assert u.data[1, i, j] == 0 + + +def test_interpolation_radius(): + nt = 11 + + grid = Grid(shape=(5, 5)) + u = TimeFunction(name="u", grid=grid, space_order=0) + src = SparseTimeFunction(name="src", grid=grid, nt=nt, npoint=1) + try: + src.interpolate(u) + assert False + except ValueError: + assert True diff --git a/tests/test_mpi.py b/tests/test_mpi.py index ab7092ba1a6..46b8e18fa57 100644 --- a/tests/test_mpi.py +++ b/tests/test_mpi.py @@ -1501,7 +1501,7 @@ def test_injection_wodup(self): """ grid = Grid(shape=(4, 4), extent=(3.0, 3.0)) - f = Function(name='f', grid=grid, space_order=0) + f = Function(name='f', grid=grid, space_order=1) f.data[:] = 0. coords = np.array([(0.5, 0.5), (0.5, 2.5), (2.5, 0.5), (2.5, 2.5)]) sf = SparseFunction(name='sf', grid=grid, npoint=len(coords), coordinates=coords) @@ -1536,7 +1536,7 @@ def test_injection_wodup_wtime(self): grid = Grid(shape=(4, 4), extent=(3.0, 3.0)) save = 3 - f = TimeFunction(name='f', grid=grid, save=save, space_order=0) + f = TimeFunction(name='f', grid=grid, save=save, space_order=1) f.data[:] = 0. coords = np.array([(0.5, 0.5), (0.5, 2.5), (2.5, 0.5), (2.5, 2.5)]) sf = SparseTimeFunction(name='sf', grid=grid, nt=save, @@ -1611,7 +1611,7 @@ def test_injection_dup(self): def test_interpolation_wodup(self): grid = Grid(shape=(4, 4), extent=(3.0, 3.0)) - f = Function(name='f', grid=grid, space_order=0) + f = Function(name='f', grid=grid, space_order=1) f.data[:] = 4. coords = [(0.5, 0.5), (0.5, 2.5), (2.5, 0.5), (2.5, 2.5)] sf = SparseFunction(name='sf', grid=grid, npoint=len(coords), coordinates=coords) diff --git a/tests/test_operator.py b/tests/test_operator.py index b2188d007cd..e685762a776 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -521,7 +521,7 @@ def test_sparsefunction_inject(self): Test injection of a SparseFunction into a Function """ grid = Grid(shape=(11, 11)) - u = Function(name='u', grid=grid, space_order=0) + u = Function(name='u', grid=grid, space_order=1) sf1 = SparseFunction(name='s', grid=grid, npoint=1) op = Operator(sf1.inject(u, expr=sf1)) @@ -542,7 +542,7 @@ def test_sparsefunction_interp(self): Test interpolation of a SparseFunction from a Function """ grid = Grid(shape=(11, 11)) - u = Function(name='u', grid=grid, space_order=0) + u = Function(name='u', grid=grid, space_order=1) sf1 = SparseFunction(name='s', grid=grid, npoint=1) op = Operator(sf1.interpolate(u)) @@ -563,7 +563,7 @@ def test_sparsetimefunction_interp(self): Test injection of a SparseTimeFunction into a TimeFunction """ grid = Grid(shape=(11, 11)) - u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=0) + u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=1) sf1 = SparseTimeFunction(name='s', grid=grid, npoint=1, nt=5) op = Operator(sf1.interpolate(u)) @@ -586,7 +586,7 @@ def test_sparsetimefunction_inject(self): Test injection of a SparseTimeFunction from a TimeFunction """ grid = Grid(shape=(11, 11)) - u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=0) + u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=1) sf1 = SparseTimeFunction(name='s', grid=grid, npoint=1, nt=5) op = Operator(sf1.inject(u, expr=3*sf1)) @@ -611,7 +611,7 @@ def test_sparsetimefunction_inject_dt(self): Test injection of the time deivative of a SparseTimeFunction into a TimeFunction """ grid = Grid(shape=(11, 11)) - u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=0) + u = TimeFunction(name='u', grid=grid, time_order=2, save=5, space_order=1) sf1 = SparseTimeFunction(name='s', grid=grid, npoint=1, nt=5, time_order=2)