diff --git a/devito/builtins/arithmetic.py b/devito/builtins/arithmetic.py index 0e1e7fbcdc..2994f6a346 100644 --- a/devito/builtins/arithmetic.py +++ b/devito/builtins/arithmetic.py @@ -1,12 +1,13 @@ import numpy as np import devito as dv -from devito.builtins.utils import make_retval +from devito.builtins.utils import make_retval, check_args __all__ = ['norm', 'sumall', 'sum', 'inner', 'mmin', 'mmax'] @dv.switchconfig(log_level='ERROR') +@check_args def norm(f, order=2): """ Compute the norm of a Function. @@ -41,6 +42,7 @@ def norm(f, order=2): @dv.switchconfig(log_level='ERROR') +@check_args def sum(f, dims=None): """ Compute the sum of the Function data over specified dimensions. @@ -94,6 +96,7 @@ def sum(f, dims=None): @dv.switchconfig(log_level='ERROR') +@check_args def sumall(f): """ Compute the sum of all Function data. @@ -123,6 +126,7 @@ def sumall(f): @dv.switchconfig(log_level='ERROR') +@check_args def inner(f, g): """ Inner product of two Functions. @@ -177,6 +181,7 @@ def inner(f, g): @dv.switchconfig(log_level='ERROR') +@check_args def mmin(f): """ Retrieve the minimum. @@ -200,6 +205,7 @@ def mmin(f): @dv.switchconfig(log_level='ERROR') +@check_args def mmax(f): """ Retrieve the maximum. diff --git a/devito/builtins/initializers.py b/devito/builtins/initializers.py index dbe271eac1..eb1379efcf 100644 --- a/devito/builtins/initializers.py +++ b/devito/builtins/initializers.py @@ -2,12 +2,13 @@ import devito as dv from devito.tools import as_tuple, as_list -from devito.builtins.utils import nbl_to_padsize, pad_outhalo +from devito.builtins.utils import check_args, nbl_to_padsize, pad_outhalo __all__ = ['assign', 'smooth', 'gaussian_smooth', 'initialize_function'] @dv.switchconfig(log_level='ERROR') +@check_args def assign(f, rhs=0, options=None, name='assign', assign_halo=False, **kwargs): """ Assign a list of RHSs to a list of Functions. @@ -85,6 +86,7 @@ def assign(f, rhs=0, options=None, name='assign', assign_halo=False, **kwargs): op(time_M=f._time_size) +@check_args def smooth(f, g, axis=None): """ Smooth a Function through simple moving average. @@ -114,6 +116,7 @@ def smooth(f, g, axis=None): dv.Operator(dv.Eq(f, g.avg(dims=axis)), name='smoother')() +@check_args def gaussian_smooth(f, sigma=1, truncate=4.0, mode='reflect'): """ Gaussian smooth function. @@ -273,6 +276,7 @@ def buff(i, j): return lhs, rhs, options +@check_args def initialize_function(function, data, nbl, mapper=None, mode='constant', name=None, pad_halo=True, **kwargs): """ diff --git a/devito/builtins/utils.py b/devito/builtins/utils.py index 32f59c731f..9c5c05aacd 100644 --- a/devito/builtins/utils.py +++ b/devito/builtins/utils.py @@ -6,7 +6,8 @@ from devito.symbolics import uxreplace from devito.tools import as_tuple -__all__ = ['make_retval', 'nbl_to_padsize', 'pad_outhalo', 'abstract_args'] +__all__ = ['make_retval', 'nbl_to_padsize', 'pad_outhalo', 'abstract_args', + 'check_args'] accumulator_mapper = { @@ -131,3 +132,23 @@ def wrapper(*args, **kwargs): return func(*processed, argmap=argmap, **kwargs) return wrapper + + +def check_args(func): + """ + Perform checks on the arguments supplied to a builtin. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + for i in args: + try: + if i.is_transient: + raise ValueError(f"Cannot apply `{func.__name__}` to transient " + "function `{i.name}`") + except AttributeError: + pass + + return func(*args, **kwargs) + + return wrapper diff --git a/tests/test_builtins.py b/tests/test_builtins.py index 102875d35a..d1ebed5d15 100644 --- a/tests/test_builtins.py +++ b/tests/test_builtins.py @@ -518,3 +518,11 @@ def test_issue_1863(self): assert type(v1) is np.int32 assert type(v2) is np.float32 assert type(v3) is np.float64 + + def test_is_transient(self): + grid = Grid(shape=(4, 4)) + + f = Function(name='f', grid=grid, is_transient=True) + + with pytest.raises(ValueError): + assign(f, 4)