Skip to content

Commit

Permalink
misc: Prevent builtins on transient functions
Browse files Browse the repository at this point in the history
  • Loading branch information
FabioLuporini committed Dec 23, 2024
1 parent 484c832 commit 2e23216
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 4 deletions.
8 changes: 7 additions & 1 deletion devito/builtins/arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import numpy as np

import devito as dv
from devito.builtins.utils import make_retval
from devito.builtins.utils import make_retval, check_builtins_args

__all__ = ['norm', 'sumall', 'sum', 'inner', 'mmin', 'mmax']


@dv.switchconfig(log_level='ERROR')
@check_builtins_args
def norm(f, order=2):
"""
Compute the norm of a Function.
Expand Down Expand Up @@ -41,6 +42,7 @@ def norm(f, order=2):


@dv.switchconfig(log_level='ERROR')
@check_builtins_args
def sum(f, dims=None):
"""
Compute the sum of the Function data over specified dimensions.
Expand Down Expand Up @@ -94,6 +96,7 @@ def sum(f, dims=None):


@dv.switchconfig(log_level='ERROR')
@check_builtins_args
def sumall(f):
"""
Compute the sum of all Function data.
Expand Down Expand Up @@ -123,6 +126,7 @@ def sumall(f):


@dv.switchconfig(log_level='ERROR')
@check_builtins_args
def inner(f, g):
"""
Inner product of two Functions.
Expand Down Expand Up @@ -177,6 +181,7 @@ def inner(f, g):


@dv.switchconfig(log_level='ERROR')
@check_builtins_args
def mmin(f):
"""
Retrieve the minimum.
Expand All @@ -200,6 +205,7 @@ def mmin(f):


@dv.switchconfig(log_level='ERROR')
@check_builtins_args
def mmax(f):
"""
Retrieve the maximum.
Expand Down
6 changes: 5 additions & 1 deletion devito/builtins/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import devito as dv
from devito.tools import as_tuple, as_list
from devito.builtins.utils import nbl_to_padsize, pad_outhalo
from devito.builtins.utils import check_builtins_args, nbl_to_padsize, pad_outhalo

__all__ = ['assign', 'smooth', 'gaussian_smooth', 'initialize_function']


@dv.switchconfig(log_level='ERROR')
@check_builtins_args
def assign(f, rhs=0, options=None, name='assign', assign_halo=False, **kwargs):
"""
Assign a list of RHSs to a list of Functions.
Expand Down Expand Up @@ -85,6 +86,7 @@ def assign(f, rhs=0, options=None, name='assign', assign_halo=False, **kwargs):
op(time_M=f._time_size)


@check_builtins_args
def smooth(f, g, axis=None):
"""
Smooth a Function through simple moving average.
Expand Down Expand Up @@ -114,6 +116,7 @@ def smooth(f, g, axis=None):
dv.Operator(dv.Eq(f, g.avg(dims=axis)), name='smoother')()


@check_builtins_args
def gaussian_smooth(f, sigma=1, truncate=4.0, mode='reflect'):
"""
Gaussian smooth function.
Expand Down Expand Up @@ -273,6 +276,7 @@ def buff(i, j):
return lhs, rhs, options


@check_builtins_args
def initialize_function(function, data, nbl, mapper=None, mode='constant',
name=None, pad_halo=True, **kwargs):
"""
Expand Down
28 changes: 27 additions & 1 deletion devito/builtins/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import numpy as np

import devito as dv
from devito.arch import Device
from devito.symbolics import uxreplace
from devito.tools import as_tuple

__all__ = ['make_retval', 'nbl_to_padsize', 'pad_outhalo', 'abstract_args']
__all__ = ['make_retval', 'nbl_to_padsize', 'pad_outhalo', 'abstract_args',
'check_builtins_args']


accumulator_mapper = {
Expand Down Expand Up @@ -131,3 +133,27 @@ def wrapper(*args, **kwargs):
return func(*processed, argmap=argmap, **kwargs)

return wrapper


def check_builtins_args(func):
"""
Perform checks on the arguments supplied to a builtin.
"""

@wraps(func)
def wrapper(*args, **kwargs):
platform = dv.configuration['platform']
if not isinstance(platform, Device):
return func(*args, **kwargs)

for i in args:
try:
if i.is_transient:
raise ValueError(f"Cannot apply `{func.__name__}` to transient "
f"function `{i.name}` on backend `{platform}`")
except AttributeError:
pass

return func(*args, **kwargs)

return wrapper
10 changes: 9 additions & 1 deletion tests/test_gpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
Dimension, MatrixSparseTimeFunction, SparseTimeFunction,
SubDimension, SubDomain, SubDomainSet, TimeFunction,
Operator, configuration, switchconfig, TensorTimeFunction,
Buffer)
Buffer, assign)
from devito.arch import get_gpu_info, get_cpu_info, Device, Cpu64
from devito.exceptions import InvalidArgument
from devito.ir import (Conditional, Expression, Section, FindNodes, FindSymbols,
Expand Down Expand Up @@ -1491,6 +1491,14 @@ def test_pickling(self):

assert str(op) == str(new_op)

def test_is_transient_w_builtins(self):
grid = Grid(shape=(4, 4))

f = Function(name='f', grid=grid, is_transient=True)

with pytest.raises(ValueError):
assign(f, 4)


class TestEdgeCases:

Expand Down

0 comments on commit 2e23216

Please sign in to comment.