Skip to content

Commit

Permalink
misc: cleanup test and fix comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Sep 25, 2024
1 parent c7360e2 commit ef76cc4
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 25 deletions.
46 changes: 26 additions & 20 deletions devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,22 +235,19 @@ def __call__(self, x0=None, fd_order=None, side=None, method=None, weights=None)
except AttributeError:
raise TypeError("fd_order incompatible with dimensions")

# In case this was called on a cross derivative we need to propagate
# the call to the nested derivative
if isinstance(self.expr, Derivative):
_fd_orders = {k: v for k, v in _fd_order.items() if k in self.expr.dims}
_x0s = {k: v for k, v in _x0.items() if k in self.expr.dims and
k not in self.dims}
new_expr = self.expr(x0=_x0s, fd_order=_fd_orders, side=side,
method=method, weights=weights)
# In case this was called on a perfect cross-derivative `u.dxdy`
# we need to propagate the call to the nested derivative
x0s = self.filter_dims(self.expr.filter_dims(_x0), neg=True)
expr = self.expr(x0=x0s, fd_order=self.expr.filter_dims(_fd_order),
side=side, method=method)
else:
new_expr = self.expr
expr = self.expr

_fd_order = tuple(v for k, v in _fd_order.items() if k in self.dims)
_fd_order = DimensionTuple(*_fd_order, getters=self.dims)
_fd_order = self.filter_dims(_fd_order, as_tuple=True)

return self._rebuild(fd_order=_fd_order, x0=_x0, side=side, method=method,
weights=weights, expr=new_expr)
weights=weights, expr=expr)

def _rebuild(self, *args, **kwargs):
kwargs['preprocessed'] = True
Expand Down Expand Up @@ -305,17 +302,30 @@ def _xreplace(self, subs):

# Resolve nested derivatives
dsubs = {k: v for k, v in subs.items() if isinstance(k, Derivative)}
new_expr = self.expr.xreplace(dsubs)
expr = self.expr.xreplace(dsubs)

subs = self._ppsubs + (subs,) # Postponed substitutions
return self._rebuild(subs=subs, expr=new_expr), True
return self._rebuild(subs=subs, expr=expr), True

@cached_property
def _metadata(self):
ret = [self.dims] + [getattr(self, i) for i in self.__rkwargs__]
ret.append(self.expr.staggered or (None,))
return tuple(ret)

def filter_dims(self, col, as_tuple=False, neg=False):
"""
Filter collectiion to only keep the derivative's dimensions as keys.
"""
if neg:
filtered = {k: v for k, v in col.items() if k not in self.dims}
else:
filtered = {k: v for k, v in col.items() if k in self.dims}
if as_tuple:
return DimensionTuple(*filtered.values(), getters=self.dims)
else:
return filtered

@property
def dims(self):
return self._dims
Expand Down Expand Up @@ -436,13 +446,9 @@ def _eval_fd(self, expr, **kwargs):
"""
# Step 1: Evaluate non-derivative x0. We currently enforce a simple 2nd order
# interpolation to avoid very expensive finite differences on top of it
x0_interp = {}
x0_deriv = {}
for d, v in self.x0.items():
if d in self.dims:
x0_deriv[d] = v
elif not d.is_Time:
x0_interp[d] = v
x0_deriv = self.filter_dims(self.x0)
x0_interp = {d: v for d, v in self.x0.items()
if d not in x0_deriv and not d.is_Time}

if x0_interp and self.method == 'FD':
expr = interp_for_fd(expr, x0_interp, **kwargs)
Expand Down
4 changes: 2 additions & 2 deletions devito/finite_differences/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def generate_fd_shortcuts(dims, so, to=0):
from devito.finite_differences.derivative import Derivative

def diff_f(expr, deriv_order, dims, fd_order, side=None, **kwargs):
# Spearate dimension to always have cross derivatives return nested
# derivatives.
# Separate dimensions to always have cross derivatives return nested
# derivatives. E.g `u.dxdy -> u.dx.dy`
dims = as_tuple(dims)
deriv_order = as_tuple(deriv_order)
fd_order = as_tuple(fd_order)
Expand Down
4 changes: 2 additions & 2 deletions devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ def _lower_exprs(expressions, subs):

# Handle Array
if isinstance(f, Array) and f.initvalue is not None:
initv = [_lower_exprs(i, subs) for i in f.initvalue]
initvalue = [_lower_exprs(i, subs) for i in f.initvalue]
# TODO: fix rebuild to avoid new name
f = f._rebuild(name='%si' % f.name, initvalue=initv)
f = f._rebuild(name='%si' % f.name, initvalue=initvalue)

mapper[i] = f.indexed[indices]
# Add dimensions map to the mapper in case dimensions are used
Expand Down
5 changes: 5 additions & 0 deletions tests/test_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,11 @@ def test_xderiv_x0(self):
- f.dx(x0=x+h_x/2).dy(x0=y+h_y/2).evaluate
assert simplify(expr) == 0

# Check x0 is correctly set
dfdxdx = f.dx(x0=x+h_x/2).dx(x0=x-h_x/2)
assert dict(dfdxdx.x0) == {x: x-h_x/2}
assert dict(dfdxdx.expr.x0) == {x: x+h_x/2}

def test_fd_new_side(self):
grid = Grid((10,))
u = Function(name="u", grid=grid, space_order=4)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def test_shifted_curl_of_vector(shift, ndim):
dorder = order or 4
for drv in drvs:
assert drv.expr in f
assert drv.fd_order == dorder
assert drv.fd_order == (dorder,)
if shift is None:
assert drv.x0 == {}
else:
Expand Down

0 comments on commit ef76cc4

Please sign in to comment.