Skip to content

Commit

Permalink
compiler: Better lowering of blocked MultiSubDimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
EdCaunt committed Jul 10, 2024
1 parent 9580e26 commit 5bfc7d7
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions devito/passes/clusters/implicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from devito.symbolics import retrieve_dimensions
from devito.tools import Bunch, frozendict, timed_pass
from devito.types import Eq, Symbol
from devito.types.dimension import BlockDimension
from devito.types.grid import MultiSubDimension

__all__ = ['generate_implicit']
Expand Down Expand Up @@ -226,21 +227,25 @@ def msdim(d):

@singledispatch
def _lower_msd(dim, cluster):
# Retval: (dynamic thickness mapper, iteration dimensions)
# Retval: (dynamic thickness mapper, iteration dimension)
return {}, None


@_lower_msd.register(MultiSubDimension)
def _(dim, cluster):
# Pull out the parent MultiSubDimension if blocked etc
i_dim = dim.implicit_dimension
mapper = {(dim.root, i): dim.functions[i_dim, mM]
for i, mM in enumerate(dim.bounds_indices)}
return mapper, i_dim


@_lower_msd.register(BlockDimension)
def _(dim, cluster):
# Pull out the parent MultiSubDimension
msd = [d for d in dim._defines if d.is_MultiSub]
assert len(msd) == 1 # Sanity check. MultiSubDimensions shouldn't be nested.
msd = msd.pop()

i_dim = msd.implicit_dimension
mapper = {(dim.root, i): msd.functions[i_dim, mM]
for i, mM in enumerate(msd.bounds_indices)}
return mapper, i_dim
return _lower_msd(msd, cluster)


def lower_msd(msdims, cluster):
Expand Down

0 comments on commit 5bfc7d7

Please sign in to comment.