diff --git a/devito/builtins/arithmetic.py b/devito/builtins/arithmetic.py index 0e1e7fbcdc..93fbc917b5 100644 --- a/devito/builtins/arithmetic.py +++ b/devito/builtins/arithmetic.py @@ -1,12 +1,13 @@ import numpy as np import devito as dv -from devito.builtins.utils import make_retval +from devito.builtins.utils import make_retval, check_builtins_args __all__ = ['norm', 'sumall', 'sum', 'inner', 'mmin', 'mmax'] @dv.switchconfig(log_level='ERROR') +@check_builtins_args def norm(f, order=2): """ Compute the norm of a Function. @@ -41,6 +42,7 @@ def norm(f, order=2): @dv.switchconfig(log_level='ERROR') +@check_builtins_args def sum(f, dims=None): """ Compute the sum of the Function data over specified dimensions. @@ -94,6 +96,7 @@ def sum(f, dims=None): @dv.switchconfig(log_level='ERROR') +@check_builtins_args def sumall(f): """ Compute the sum of all Function data. @@ -123,6 +126,7 @@ def sumall(f): @dv.switchconfig(log_level='ERROR') +@check_builtins_args def inner(f, g): """ Inner product of two Functions. @@ -177,6 +181,7 @@ def inner(f, g): @dv.switchconfig(log_level='ERROR') +@check_builtins_args def mmin(f): """ Retrieve the minimum. @@ -200,6 +205,7 @@ def mmin(f): @dv.switchconfig(log_level='ERROR') +@check_builtins_args def mmax(f): """ Retrieve the maximum. diff --git a/devito/builtins/initializers.py b/devito/builtins/initializers.py index dbe271eac1..7855ee75a2 100644 --- a/devito/builtins/initializers.py +++ b/devito/builtins/initializers.py @@ -2,12 +2,13 @@ import devito as dv from devito.tools import as_tuple, as_list -from devito.builtins.utils import nbl_to_padsize, pad_outhalo +from devito.builtins.utils import check_builtins_args, nbl_to_padsize, pad_outhalo __all__ = ['assign', 'smooth', 'gaussian_smooth', 'initialize_function'] @dv.switchconfig(log_level='ERROR') +@check_builtins_args def assign(f, rhs=0, options=None, name='assign', assign_halo=False, **kwargs): """ Assign a list of RHSs to a list of Functions. @@ -85,6 +86,7 @@ def assign(f, rhs=0, options=None, name='assign', assign_halo=False, **kwargs): op(time_M=f._time_size) +@check_builtins_args def smooth(f, g, axis=None): """ Smooth a Function through simple moving average. @@ -114,6 +116,7 @@ def smooth(f, g, axis=None): dv.Operator(dv.Eq(f, g.avg(dims=axis)), name='smoother')() +@check_builtins_args def gaussian_smooth(f, sigma=1, truncate=4.0, mode='reflect'): """ Gaussian smooth function. @@ -273,6 +276,7 @@ def buff(i, j): return lhs, rhs, options +@check_builtins_args def initialize_function(function, data, nbl, mapper=None, mode='constant', name=None, pad_halo=True, **kwargs): """ diff --git a/devito/builtins/utils.py b/devito/builtins/utils.py index 32f59c731f..df7c7ddca3 100644 --- a/devito/builtins/utils.py +++ b/devito/builtins/utils.py @@ -3,10 +3,12 @@ import numpy as np import devito as dv +from devito.arch import Device from devito.symbolics import uxreplace from devito.tools import as_tuple -__all__ = ['make_retval', 'nbl_to_padsize', 'pad_outhalo', 'abstract_args'] +__all__ = ['make_retval', 'nbl_to_padsize', 'pad_outhalo', 'abstract_args', + 'check_builtins_args'] accumulator_mapper = { @@ -131,3 +133,27 @@ def wrapper(*args, **kwargs): return func(*processed, argmap=argmap, **kwargs) return wrapper + + +def check_builtins_args(func): + """ + Perform checks on the arguments supplied to a builtin. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + platform = dv.configuration['platform'] + if not isinstance(platform, Device): + return func(*args, **kwargs) + + for i in args: + try: + if i.is_transient: + raise ValueError(f"Cannot apply `{func.__name__}` to transient " + f"function `{i.name}` on backend `{platform}`") + except AttributeError: + pass + + return func(*args, **kwargs) + + return wrapper diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 1be9ec3fbe..a22ab93df4 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -9,7 +9,7 @@ Dimension, MatrixSparseTimeFunction, SparseTimeFunction, SubDimension, SubDomain, SubDomainSet, TimeFunction, Operator, configuration, switchconfig, TensorTimeFunction, - Buffer) + Buffer, assign) from devito.arch import get_gpu_info, get_cpu_info, Device, Cpu64 from devito.exceptions import InvalidArgument from devito.ir import (Conditional, Expression, Section, FindNodes, FindSymbols, @@ -1491,6 +1491,14 @@ def test_pickling(self): assert str(op) == str(new_op) + def test_is_transient_w_builtins(self): + grid = Grid(shape=(4, 4)) + + f = Function(name='f', grid=grid, is_transient=True) + + with pytest.raises(ValueError): + assign(f, 4) + class TestEdgeCases: