Skip to content

Commit

Permalink
Merge pull request #2248 from devitocodes/patch-unexp-cross-derivs
Browse files Browse the repository at this point in the history
compiler: Patch symbolic coefficients over cross derivatives
  • Loading branch information
FabioLuporini authored Oct 27, 2023
2 parents 4b0e39a + a30b392 commit 7a235ed
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 21 deletions.
26 changes: 25 additions & 1 deletion devito/finite_differences/coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,33 @@ def generate_subs(deriv_order, function, index):
# NOTE: Do we want to throw a warning if the same arg has
# been provided twice?
args_provided = list(set(args_provided))
not_provided = [i for i in args_present if i not in frozenset(args_provided)]

rules = {}
not_provided = []
for i0 in args_present:
if any(i0 == i1 for i1 in args_provided):
# Perfect match, as expected by the legacy custom coeffs API
continue

# TODO: to make cross-derivs work, we must relax `not_provided` by
# checking not for equality, but rather for inclusion. This is ugly,
# but basically a major revamp is the only alternative... and for now,
# it does the trick
mapper = {}
deriv_order, expr, dim = i0
try:
for k, v in subs.rules.items():
ofs, do, f, d = k.args
if deriv_order == do and dim is d and f in expr._functions:
mapper[k.func(ofs, do, expr, d)] = v
except AttributeError:
assert subs is None

if mapper:
rules.update(mapper)
else:
not_provided.append(i0)

for i in not_provided:
rules = {**rules, **generate_subs(*i)}

Expand Down
23 changes: 16 additions & 7 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,13 @@ def _symbolic_functions(self):
def _uses_symbolic_coefficients(self):
return bool(self._symbolic_functions)

@cached_property
def _coeff_symbol(self, *args, **kwargs):
if self._uses_symbolic_coefficients:
return W
else:
raise ValueError("Couldn't find any symbolic coefficients")

def _eval_at(self, func):
if not func.is_Staggered:
# Cartesian grid, do no waste time
Expand Down Expand Up @@ -327,6 +334,10 @@ def highest_priority(DiffOp):
return sorted(DiffOp._args_diff, key=prio, reverse=True)[0]


# Abstract symbol representing a symbolic coefficient
W = sympy.Function('W')


class DifferentiableOp(Differentiable):

__sympy_class__ = None
Expand Down Expand Up @@ -606,12 +617,13 @@ def __init_finalize__(self, *args, **kwargs):
assert isinstance(d, StencilDimension) and d.symbolic_size == len(weights)
assert isinstance(weights, (list, tuple, np.ndarray))

try:
self._spacings = set().union(*[i.find(Spacing) for i in weights])
except AttributeError:
self._spacing = set()
# Normalize `weights`
weights = tuple(sympy.sympify(i) for i in weights)

self._spacings = set().union(*[i.find(Spacing) for i in weights])

kwargs['scope'] = 'constant'
kwargs['initvalue'] = weights

super().__init_finalize__(*args, **kwargs)

Expand Down Expand Up @@ -766,9 +778,6 @@ def _new_rawargs(self, *args, **kwargs):
kwargs.pop('is_commutative', None)
return self.func(*args, **kwargs)

def _coeff_symbol(self, *args, **kwargs):
return self.base._coeff_symbol(*args, **kwargs)


class diffify(object):

Expand Down
8 changes: 0 additions & 8 deletions devito/types/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,14 +198,6 @@ def coefficients(self):
"""Form of the coefficients of the function."""
return self._coefficients

@cached_property
def _coeff_symbol(self):
if self.coefficients == 'symbolic':
return sympy.Function('W')
else:
raise ValueError("Function was not declared with symbolic "
"coefficients.")

@cached_property
def shape(self):
"""
Expand Down
101 changes: 101 additions & 0 deletions tests/test_symbolic_coefficients.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ def test_aggregate_w_custom_coeffs(self):

def test_cross_derivs(self):
grid = Grid(shape=(11, 11, 11))

q = TimeFunction(name='q', grid=grid, space_order=8, time_order=2,
coefficients='symbolic')
q0 = TimeFunction(name='q', grid=grid, space_order=8, time_order=2)
Expand All @@ -389,3 +390,103 @@ def test_cross_derivs(self):

assert(eq0.evaluate.evalf(_PRECISION).__repr__() ==
eq1.evaluate.evalf(_PRECISION).__repr__())

def test_cross_derivs_imperfect(self):
grid = Grid(shape=(11, 11, 11))

p = TimeFunction(name='p', grid=grid, space_order=4, time_order=2,
coefficients='symbolic')
q = TimeFunction(name='q', grid=grid, space_order=4, time_order=2,
coefficients='symbolic')

p0 = TimeFunction(name='p', grid=grid, space_order=4, time_order=2)
q0 = TimeFunction(name='q', grid=grid, space_order=4, time_order=2)

eq0 = Eq(q0.forward, (q0.dx + p0.dx).dy)
eq1 = Eq(q.forward, (q.dx + p.dx).dy)

assert(eq0.evaluate.evalf(_PRECISION).__repr__() ==
eq1.evaluate.evalf(_PRECISION).__repr__())

def test_nested_subs(self):
grid = Grid(shape=(11, 11))
x, y = grid.dimensions
hx, hy = grid.spacing_symbols

p = TimeFunction(name='p', grid=grid, space_order=2,
coefficients='symbolic')

coeffs0 = np.array([100, 100, 100])
coeffs1 = np.array([200, 200, 200])

subs = Substitutions(Coefficient(1, p, x, coeffs0),
Coefficient(1, p, y, coeffs1))

eq = Eq(p.forward, p.dx.dy, coefficients=subs)

mul = lambda e: sp.Mul(e, 200, evaluate=False)
term0 = mul(p*100 +
p.subs(x, x-hx)*100 +
p.subs(x, x+hx)*100)
term1 = mul(p.subs(y, y-hy)*100 +
p.subs({x: x-hx, y: y-hy})*100 +
p.subs({x: x+hx, y: y-hy})*100)
term2 = mul(p.subs(y, y+hy)*100 +
p.subs({x: x-hx, y: y+hy})*100 +
p.subs({x: x+hx, y: y+hy})*100)

# `str` simply because some objects are of type EvalDerivative
assert str(eq.evaluate.rhs) == str(term0 + term1 + term2)

def test_compound_subs(self):
grid = Grid(shape=(11,))
x, = grid.dimensions
hx, = grid.spacing_symbols

f = Function(name='f', grid=grid, space_order=2)
p = TimeFunction(name='p', grid=grid, space_order=2,
coefficients='symbolic')

coeffs0 = np.array([100, 100, 100])

subs = Substitutions(Coefficient(1, p, x, coeffs0))

eq = Eq(p.forward, (f*p).dx, coefficients=subs)

term0 = f*p*100
term1 = (f*p*100).subs(x, x-hx)
term2 = (f*p*100).subs(x, x+hx)

# `str` simply because some objects are of type EvalDerivative
assert str(eq.evaluate.rhs) == str(term0 + term1 + term2)

def test_compound_nested_subs(self):
grid = Grid(shape=(11, 11))
x, y = grid.dimensions
hx, hy = grid.spacing_symbols

f = Function(name='f', grid=grid, space_order=2)
p = TimeFunction(name='p', grid=grid, space_order=2,
coefficients='symbolic')

coeffs0 = np.array([100, 100, 100])
coeffs1 = np.array([200, 200, 200])

subs = Substitutions(Coefficient(1, p, x, coeffs0),
Coefficient(1, p, y, coeffs1))

eq = Eq(p.forward, (f*p.dx).dy, coefficients=subs)

mul = lambda e, i: sp.Mul(f.subs(y, y+i*hy), e, 200, evaluate=False)
term0 = mul(p*100 +
p.subs(x, x-hx)*100 +
p.subs(x, x+hx)*100, 0)
term1 = mul(p.subs(y, y-hy)*100 +
p.subs({x: x-hx, y: y-hy})*100 +
p.subs({x: x+hx, y: y-hy})*100, -1)
term2 = mul(p.subs(y, y+hy)*100 +
p.subs({x: x-hx, y: y+hy})*100 +
p.subs({x: x+hx, y: y+hy})*100, 1)

# `str` simply because some objects are of type EvalDerivative
assert str(eq.evaluate.rhs) == str(term0 + term1 + term2)
24 changes: 19 additions & 5 deletions tests/test_unexpansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_backward_dt2(self):
assert_structure(op, ['t,x,y'], 't,x,y')


class TestSymbolicCoefficients(object):
class TestSymbolicCoeffs(object):

def test_fallback_to_default(self):
grid = Grid(shape=(8, 8, 8))
Expand All @@ -39,13 +39,27 @@ def test_fallback_to_default(self):
op.cfunction

def test_numeric_coeffs(self):
grid = Grid(shape=(11,), extent=(10.,))
grid = Grid(shape=(11, 11), extent=(10., 10.))

u = Function(name='u', grid=grid, coefficients='symbolic', space_order=2)
v = Function(name='v', grid=grid, coefficients='symbolic', space_order=2)

coeffs = Substitutions(Coefficient(2, u, grid.dimensions[0], np.zeros(3)))
coeffs = Substitutions(Coefficient(2, u, grid.dimensions[0], np.zeros(3)),
Coefficient(2, u, grid.dimensions[1], np.zeros(3)))

op = Operator(Eq(u, u.dx2, coefficients=coeffs), opt=({'expand': False},))
op.cfunction
opt = ('advanced', {'expand': False})

# Pure derivative
Operator(Eq(u, u.dx2, coefficients=coeffs), opt=opt).cfunction

# Mixed derivative
Operator(Eq(u, u.dx.dx, coefficients=coeffs), opt=opt).cfunction

# Non-perfect mixed derivative
Operator(Eq(u, (u.dx + v.dx).dx, coefficients=coeffs), opt=opt).cfunction

# Compound expression
Operator(Eq(u, (v*u.dx).dy, coefficients=coeffs), opt=opt).cfunction


class Test1Pass(object):
Expand Down

0 comments on commit 7a235ed

Please sign in to comment.