Skip to content

Commit

Permalink
Merge pull request #2243 from devitocodes/fix-custom-coeff-lowering
Browse files Browse the repository at this point in the history
compiler: Patch custom coefficients
  • Loading branch information
FabioLuporini authored Oct 24, 2023
2 parents d566d31 + 698d6a8 commit ae40de0
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 11 deletions.
4 changes: 3 additions & 1 deletion devito/passes/iet/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,9 @@ def place_definitions(self, iet, globs=None, **kwargs):
includes = set()
if isinstance(iet, EntryFunction) and globs:
for i in sorted(globs, key=lambda f: f.name):
includes.add(self._alloc_array_on_global_mem(iet, i, storage))
v = self._alloc_array_on_global_mem(iet, i, storage)
if v:
includes.add(v)

iet, efuncs = self._inject_definitions(iet, storage)

Expand Down
14 changes: 5 additions & 9 deletions devito/symbolics/extended_sympy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
import sympy
from sympy import Expr, Integer, Function, Number, Tuple, sympify
from sympy import Expr, Function, Number, Tuple, sympify
from sympy.core.decorators import call_highest_priority

from devito.tools import (Pickable, as_tuple, is_integer, float2, float3, float4, # noqa
Expand Down Expand Up @@ -278,14 +278,10 @@ class ListInitializer(sympy.Expr, Pickable):
def __new__(cls, params):
args = []
for p in as_tuple(params):
if isinstance(p, str):
args.append(Symbol(p))
elif is_integer(p):
args.append(Integer(p))
elif not isinstance(p, Expr):
raise ValueError("`params` must be an iterable of Expr or str")
else:
args.append(p)
try:
args.append(sympify(p))
except sympy.SympifyError:
raise ValueError("Illegal param `%s`" % p)
obj = sympy.Expr.__new__(cls, *args)
obj.params = tuple(args)
return obj
Expand Down
12 changes: 11 additions & 1 deletion tests/test_unexpansion.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np

from conftest import assert_structure, get_params, get_arrays, check_array
from devito import Buffer, Eq, Function, TimeFunction, Grid, Operator, cos, sin
from devito import (Buffer, Eq, Function, TimeFunction, Grid, Operator,
Substitutions, Coefficient, cos, sin)
from devito.types import Symbol


Expand Down Expand Up @@ -37,6 +38,15 @@ def test_fallback_to_default(self):
op.arguments(dt=1, time_M=10)
op.cfunction

def test_numeric_coeffs(self):
grid = Grid(shape=(11,), extent=(10.,))
u = Function(name='u', grid=grid, coefficients='symbolic', space_order=2)

coeffs = Substitutions(Coefficient(2, u, grid.dimensions[0], np.zeros(3)))

op = Operator(Eq(u, u.dx2, coefficients=coeffs), opt=({'expand': False},))
op.cfunction


class Test1Pass(object):

Expand Down

0 comments on commit ae40de0

Please sign in to comment.