Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

misc: Prevent builtins on transient functions #2506

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion devito/builtins/arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import numpy as np

import devito as dv
from devito.builtins.utils import make_retval
from devito.builtins.utils import make_retval, check_args

__all__ = ['norm', 'sumall', 'sum', 'inner', 'mmin', 'mmax']


@dv.switchconfig(log_level='ERROR')
@check_args
def norm(f, order=2):
"""
Compute the norm of a Function.
Expand Down Expand Up @@ -41,6 +42,7 @@ def norm(f, order=2):


@dv.switchconfig(log_level='ERROR')
@check_args
def sum(f, dims=None):
"""
Compute the sum of the Function data over specified dimensions.
Expand Down Expand Up @@ -94,6 +96,7 @@ def sum(f, dims=None):


@dv.switchconfig(log_level='ERROR')
@check_args
def sumall(f):
"""
Compute the sum of all Function data.
Expand Down Expand Up @@ -123,6 +126,7 @@ def sumall(f):


@dv.switchconfig(log_level='ERROR')
@check_args
def inner(f, g):
"""
Inner product of two Functions.
Expand Down Expand Up @@ -177,6 +181,7 @@ def inner(f, g):


@dv.switchconfig(log_level='ERROR')
@check_args
def mmin(f):
"""
Retrieve the minimum.
Expand All @@ -200,6 +205,7 @@ def mmin(f):


@dv.switchconfig(log_level='ERROR')
@check_args
def mmax(f):
"""
Retrieve the maximum.
Expand Down
6 changes: 5 additions & 1 deletion devito/builtins/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@

import devito as dv
from devito.tools import as_tuple, as_list
from devito.builtins.utils import nbl_to_padsize, pad_outhalo
from devito.builtins.utils import check_args, nbl_to_padsize, pad_outhalo

__all__ = ['assign', 'smooth', 'gaussian_smooth', 'initialize_function']


@dv.switchconfig(log_level='ERROR')
@check_args
def assign(f, rhs=0, options=None, name='assign', assign_halo=False, **kwargs):
"""
Assign a list of RHSs to a list of Functions.
Expand Down Expand Up @@ -85,6 +86,7 @@ def assign(f, rhs=0, options=None, name='assign', assign_halo=False, **kwargs):
op(time_M=f._time_size)


@check_args
def smooth(f, g, axis=None):
"""
Smooth a Function through simple moving average.
Expand Down Expand Up @@ -114,6 +116,7 @@ def smooth(f, g, axis=None):
dv.Operator(dv.Eq(f, g.avg(dims=axis)), name='smoother')()


@check_args
def gaussian_smooth(f, sigma=1, truncate=4.0, mode='reflect'):
"""
Gaussian smooth function.
Expand Down Expand Up @@ -273,6 +276,7 @@ def buff(i, j):
return lhs, rhs, options


@check_args
def initialize_function(function, data, nbl, mapper=None, mode='constant',
name=None, pad_halo=True, **kwargs):
"""
Expand Down
23 changes: 22 additions & 1 deletion devito/builtins/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from devito.symbolics import uxreplace
from devito.tools import as_tuple

__all__ = ['make_retval', 'nbl_to_padsize', 'pad_outhalo', 'abstract_args']
__all__ = ['make_retval', 'nbl_to_padsize', 'pad_outhalo', 'abstract_args',
'check_args']


accumulator_mapper = {
Expand Down Expand Up @@ -131,3 +132,23 @@ def wrapper(*args, **kwargs):
return func(*processed, argmap=argmap, **kwargs)

return wrapper


def check_args(func):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check_builtins_args

"""
Perform checks on the arguments supplied to a builtin.
"""

@wraps(func)
def wrapper(*args, **kwargs):
for i in args:
try:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should only be an issue on Device platform, CPU transient do hold the data

if i.is_transient:
raise ValueError(f"Cannot apply `{func.__name__}` to transient "
"function `{i.name}`")
except AttributeError:
pass

return func(*args, **kwargs)

return wrapper
8 changes: 8 additions & 0 deletions tests/test_builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,3 +518,11 @@ def test_issue_1863(self):
assert type(v1) is np.int32
assert type(v2) is np.float32
assert type(v3) is np.float64

def test_is_transient(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mark device

grid = Grid(shape=(4, 4))

f = Function(name='f', grid=grid, is_transient=True)

with pytest.raises(ValueError):
assign(f, 4)
Loading