Skip to content

Commit

Permalink
Merge pull request #1229 from devitocodes/indexify-fix
Browse files Browse the repository at this point in the history
indexify fix
  • Loading branch information
mloubout authored Apr 14, 2020
2 parents 17722a2 + a8b7ab0 commit decce70
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 32 deletions.
21 changes: 2 additions & 19 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,26 +305,9 @@ def _lower_exprs(cls, expressions, **kwargs):
else:
dimension_map = {}

mapper = {}

# Handle Functions (typical case)
for f in retrieve_functions(expr):
# Get spacing symbols for replacement
spacings = [i.spacing for i in f.dimensions]

# Only keep the ones used as indices
spacings = [s for i, s in enumerate(spacings)
if s.free_symbols.intersection(f.args[i].free_symbols)]

# Substitution for each index
subs = {**{s: 1 for s in spacings}, **dimension_map}

# Introduce shifting to align with the computational domain,
# and apply substitutions
indices = [(a - i + o).xreplace(subs)
for a, i, o in zip(f.args, f.origin, f._size_nodomain.left)]

mapper[f] = f.indexed[indices]
mapper = {f: f.indexify(lshift=True, subs=dimension_map)
for f in retrieve_functions(expr)}

# Handle Indexeds (from index notation)
for i in retrieve_indexed(expr):
Expand Down
23 changes: 10 additions & 13 deletions devito/types/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ def symbolic_shape(self):
for i, j, k in zip(domain, halo, padding))
return DimensionTuple(*ret, getters=self.dimensions)

@property
@cached_property
def indexed(self):
"""The wrapped IndexedData object."""
return IndexedData(self.name, shape=self.shape, function=self.function)
Expand Down Expand Up @@ -919,25 +919,22 @@ def _data_alignment(self):
"""
return default_allocator().guaranteed_alignment

def indexify(self, indices=None):
def indexify(self, indices=None, lshift=False, subs=None):
"""Create a types.Indexed from the current object."""
if indices is not None:
return Indexed(self.indexed, *indices)

# Get spacing symbols for replacement
spacings = [i.spacing for i in self.dimensions]

# Only keep the ones used as indices.
spacings = [s for i, s in enumerate(spacings)
if s.free_symbols.intersection(self.args[i].free_symbols)]

# Substitution for each index
subs = {s: 1 for s in spacings}
# Substitution for each index (spacing only used in own dimension)
subs = subs or {}
subs = [{**{d.spacing: 1, -d.spacing: -1}, **subs} for d in self.dimensions]

# Add halo shift
shift = self._size_nodomain.left if lshift else tuple([0]*len(self.dimensions))
# Indices after substitutions
indices = [(a - o).xreplace(subs) for a, o in zip(self.args, self.origin)]
indices = [(a - o + f).xreplace(s) for a, o, f, s in
zip(self.args, self.origin, shift, subs)]

return Indexed(self.indexed, *indices)
return self.indexed[indices]

def __getitem__(self, index):
"""Shortcut for ``self.indexed[index]``."""
Expand Down
17 changes: 17 additions & 0 deletions tests/test_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from devito import (ConditionalDimension, Grid, Function, TimeFunction, SparseFunction, # noqa
Eq, Operator, Constant, Dimension, SubDimension, switchconfig)
from devito.ir.iet import Expression, Iteration, FindNodes, retrieve_iteration_tree
from devito.symbolics import indexify, retrieve_functions
from devito.types import Array


Expand Down Expand Up @@ -547,6 +548,22 @@ def test_spacial_subsampling(self):
# Verify that u2[x,y]= u[2*x, 2*y]
assert np.allclose(u.data[:-1, 0:-1:2, 0:-1:2], u2.data[:-1, :, :])

def test_time_subsampling_fd(self):
nt = 19
grid = Grid(shape=(11, 11))
x, y = grid.dimensions
time = grid.time_dim

factor = 4
time_subsampled = ConditionalDimension('t_sub', parent=time, factor=factor)
usave = TimeFunction(name='usave', grid=grid, save=(nt+factor-1)//factor,
time_dim=time_subsampled, time_order=2)

dx2 = [indexify(i) for i in retrieve_functions(usave.dt2.evaluate)]
assert dx2 == [usave[time_subsampled - 1, x, y],
usave[time_subsampled + 1, x, y],
usave[time_subsampled, x, y]]

def test_subsampled_fd(self):
"""
Test that the FD shortcuts are handled correctly with ConditionalDimensions
Expand Down

0 comments on commit decce70

Please sign in to comment.