Skip to content

Commit

Permalink
Merge pull request #2185 from devitocodes/custom-fd-v3
Browse files Browse the repository at this point in the history
api: fix symbolic coeffs for cross derivatives
  • Loading branch information
mloubout authored Aug 11, 2023
2 parents ca2960d + 36b6b21 commit c8e5415
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 57 deletions.
20 changes: 15 additions & 5 deletions devito/finite_differences/coefficients.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import sympy
import numpy as np
from cached_property import cached_property

from devito.finite_differences import generate_indices
from devito.finite_differences.tools import numeric_weights, symbolic_weights
from devito.tools import filter_ordered, as_tuple

__all__ = ['Coefficient', 'Substitutions', 'default_rules']
Expand Down Expand Up @@ -245,15 +245,25 @@ def generate_subs(deriv_order, function, index):
subs = {}

mapper = {dim: index}
# Get full range of indices and weights
indices, x0 = generate_indices(function, dim,
fd_order, side=None, x0=mapper)
sweights = symbolic_weights(function, deriv_order, indices, x0)

# Actual FD used indices and weights
if deriv_order == 1 and fd_order == 2:
fd_order = 1

indices, x0 = generate_indices(function, dim,
fd_order, side=None, x0=mapper)

coeffs = sympy.finite_diff_weights(deriv_order, indices, x0)[-1][-1]
coeffs = numeric_weights(deriv_order, indices, x0)

for (c, i) in zip(coeffs, indices):
subs.update({function._coeff_symbol(i, deriv_order, function, index): c})

for j in range(len(coeffs)):
subs.update({function._coeff_symbol
(indices[j], deriv_order, function, index): coeffs[j]})
# Set all unused weights to zero
subs.update({w: 0 for w in sweights if w not in subs})

return subs

Expand Down
3 changes: 3 additions & 0 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,9 @@ def _new_rawargs(self, *args, **kwargs):
kwargs.pop('is_commutative', None)
return self.func(*args, **kwargs)

def _coeff_symbol(self, *args, **kwargs):
return self.base._coeff_symbol(*args, **kwargs)


class diffify(object):

Expand Down
3 changes: 2 additions & 1 deletion devito/finite_differences/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def check_symbolic(func):
def wrapper(expr, *args, **kwargs):
if expr._uses_symbolic_coefficients:
expr_dict = expr.as_coefficients_dict()
if any(len(expr_dict) > 1 for item in expr_dict):
if any(v > 1 for k, v in expr_dict.items()):
raise NotImplementedError("Applying the chain rule to functions "
"with symbolic coefficients is not currently "
"supported")
Expand Down Expand Up @@ -337,6 +337,7 @@ def generate_indices_staggered(expr, dim, order, side=None, x0=None):
ind0 = expr.indices_ref[dim]
except AttributeError:
ind0 = start

if start != ind0:
if order < 2:
indices = [start - diff/2, start + diff/2]
Expand Down
86 changes: 37 additions & 49 deletions examples/seismic/tutorials/07_DRP_schemes.ipynb

Large diffs are not rendered by default.

19 changes: 17 additions & 2 deletions tests/test_symbolic_coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@ def test_default_rules(self, order, stagger):
staggered=staggered)
u1 = TimeFunction(name='u', grid=grid, time_order=order, space_order=order,
staggered=staggered, coefficients='symbolic')
eq0 = Eq(-u0.dx+u0.dt)

eq0 = Eq(u0.dt-u0.dx)
eq1 = Eq(u1.dt-u1.dx)
assert(eq0.evalf(_PRECISION).__repr__() == eq1.evalf(_PRECISION).__repr__())

assert(eq0.evaluate.evalf(_PRECISION).__repr__() ==
eq1.evaluate.evalf(_PRECISION).__repr__())

@pytest.mark.parametrize('expr, sorder, dorder, dim, weights, expected', [
('u.dx', 2, 1, 0, (-0.6, 0.1, 0.6),
Expand Down Expand Up @@ -374,3 +377,15 @@ def test_aggregate_w_custom_coeffs(self):
assert aggregated.args[1] == q.dx2

Operator([Eq(q.forward, expr)])(time_M=2) # noqa

def test_cross_derivs(self):
grid = Grid(shape=(11, 11, 11))
q = TimeFunction(name='q', grid=grid, space_order=8, time_order=2,
coefficients='symbolic')
q0 = TimeFunction(name='q', grid=grid, space_order=8, time_order=2)

eq0 = Eq(q0.forward, q0.dx.dy)
eq1 = Eq(q.forward, q.dx.dy)

assert(eq0.evaluate.evalf(_PRECISION).__repr__() ==
eq1.evaluate.evalf(_PRECISION).__repr__())

0 comments on commit c8e5415

Please sign in to comment.