diff --git a/devito/passes/clusters/implicit.py b/devito/passes/clusters/implicit.py index 100ba4842a..60b435dbde 100644 --- a/devito/passes/clusters/implicit.py +++ b/devito/passes/clusters/implicit.py @@ -11,6 +11,7 @@ from devito.symbolics import retrieve_dimensions from devito.tools import Bunch, frozendict, timed_pass from devito.types import Eq, Symbol +from devito.types.dimension import BlockDimension from devito.types.grid import MultiSubDimension __all__ = ['generate_implicit'] @@ -226,21 +227,25 @@ def msdim(d): @singledispatch def _lower_msd(dim, cluster): - # Retval: (dynamic thickness mapper, iteration dimensions) + # Retval: (dynamic thickness mapper, iteration dimension) return {}, None @_lower_msd.register(MultiSubDimension) def _(dim, cluster): - # Pull out the parent MultiSubDimension if blocked etc + i_dim = dim.implicit_dimension + mapper = {(dim.root, i): dim.functions[i_dim, mM] + for i, mM in enumerate(dim.bounds_indices)} + return mapper, i_dim + + +@_lower_msd.register(BlockDimension) +def _(dim, cluster): + # Pull out the parent MultiSubDimension msd = [d for d in dim._defines if d.is_MultiSub] assert len(msd) == 1 # Sanity check. MultiSubDimensions shouldn't be nested. msd = msd.pop() - - i_dim = msd.implicit_dimension - mapper = {(dim.root, i): msd.functions[i_dim, mM] - for i, mM in enumerate(msd.bounds_indices)} - return mapper, i_dim + return _lower_msd(msd, cluster) def lower_msd(msdims, cluster):