Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

types: Remove increments from SubDimension names #2257

Closed
wants to merge 27 commits into from
Closed
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
9dc4682
WIP
FabioLuporini Oct 18, 2023
da6975b
types: Started work on functionality to resolve clashing dimension na…
EdCaunt Nov 1, 2023
d1a73d7
types: Tidy up and investigation into failures
EdCaunt Nov 1, 2023
0426692
types: Unclashing now works, but is too eager
EdCaunt Nov 2, 2023
f2322c5
types: Consistent failure pattern now
EdCaunt Nov 2, 2023
15cb6a9
tests: Tightened relational_classes test
EdCaunt Nov 2, 2023
bec6aa7
misc: tidy up
EdCaunt Nov 2, 2023
52a9a6e
compiler: Widened dimension search to include ConditionalDimension co…
EdCaunt Nov 2, 2023
99c4d49
compiler: Limited dimension extraction to indices and included Condit…
EdCaunt Nov 3, 2023
39aa594
dsl: Modifications to uxreplace
EdCaunt Nov 3, 2023
0442dc2
dsl: uxreplace now applies to ConditionalDimension conditions
EdCaunt Nov 3, 2023
3b2a7d9
misc: Tidy up
EdCaunt Nov 3, 2023
b0308a1
tests: Added test for new uxreplace functionality
EdCaunt Nov 3, 2023
f86178b
misc: Tidy up
EdCaunt Nov 3, 2023
e95a51b
dsl: Tweaked uxreplace
EdCaunt Nov 3, 2023
59dd3ae
dsl: Added deep mode to uxreplace
EdCaunt Nov 6, 2023
b946c7c
dsl: uxreplace now uses underlying attributes rather than properties …
EdCaunt Nov 6, 2023
0418235
mpi: Extended uxreplace deep option to halotouch
EdCaunt Nov 6, 2023
2f64dfc
dsl: added dxreplace for uxreplace + dimension substitution
EdCaunt Nov 7, 2023
6477cd8
compiler: Moved dimension unclashing after creation of LoweredEq
EdCaunt Nov 8, 2023
907d1d4
dsl: dxreplace now replaces dimensions inside ConditionalDimension fa…
EdCaunt Nov 8, 2023
b29d087
misc: flake8
EdCaunt Nov 9, 2023
6e3f274
types: Restored increments on SubDimension bounds
EdCaunt Nov 10, 2023
66abbc5
types: Updated dimension tagging for SubDomainSet
EdCaunt Nov 10, 2023
60a5d0d
tests: Updated tree structures to match new dimension names
EdCaunt Nov 14, 2023
fc322c1
examples: Updated the generated code printed in the convection example
EdCaunt Nov 14, 2023
6eeb17b
examples: Updated the SubDomains notebook
EdCaunt Nov 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 52 additions & 2 deletions devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from collections.abc import Iterable
from itertools import groupby

from sympy import sympify

from devito.symbolics import retrieve_indexed, uxreplace, retrieve_dimensions
from devito.symbolics import (retrieve_indexed, uxreplace, dxreplace,
retrieve_dimensions)
from devito.tools import Ordering, as_tuple, flatten, filter_sorted, filter_ordered
from devito.ir.support import pull_dims
from devito.types import Dimension, IgnoreDimSort
from devito.types.basic import AbstractFunction

__all__ = ['dimension_sort', 'lower_exprs']
__all__ = ['dimension_sort', 'lower_exprs', 'separate_dimensions']


def dimension_sort(expr):
Expand Down Expand Up @@ -147,3 +150,50 @@ def lower_exprs(expressions, **kwargs):
else:
assert len(processed) == 1
return processed.pop()


def separate_dimensions(expressions):
"""
Rename Dimensions with clashing names within expressions.
"""
resolutions = {}
count = {} # Keep track of increments on dim names
processed = []
for e in expressions:
# Just want dimensions which appear in the expression
# Dimensions in indices
dims = set().union(*tuple(set(i.function.dimensions)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Un-necessary extra "set" work, just set().union(*[i.function.dimensions ...])

for i in retrieve_indexed(e)))
# Dimensions in conditions and ConditionalDimension parents
dims = dims.union(*tuple(pull_dims(d.condition, flag=False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you probably want to unclutter here and move most of this logic into pull_dims

ideally, all you need to do here would be:

dims = pull_dims(e, flag=...)

# Sort for groupby
dims = sorted(dims, key=lambda x: x.name)

because here you're performing a search which, in theory, is just one, atomic step of separate_dimensions

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIrc, what's required here is more picky than pull_dims (pull dims selects more dimensions than are needed). It just wants dimensions in the indices, conditional parent, and conditional condition

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same, no need for tuple

for d in dims if d.is_Conditional
and d.condition is not None),
set(d.parent for d in dims if d.is_Conditional))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the whole thing seems to be a bit intricated retrieve_dimensions(e, deep=True) + e.implicit_dims should give you all dimensions.
Also why are we ignoring the conditionals with a factor?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue is that retrieve_dimensions collects too many things, not too few (even with flag=False). It wants to retrieve only indices in the equation, conditional dimension parent, and dimensions in the conditional dimension condition. Good point on the factor, that should probably be added in.

Fabio pointed out that separate_equations probably wants to be after creation of LoweredEQ, so may need to rethink this anyway.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It wants to retrieve only indices in the equation

Why would a dimsion use as a free symbol rather than an index be ignored? It doesn't change the issue it's gonna be the wrong value if not changed. Any time both the conditional and the parent are used together irrespective of how is a clash.

# Sort for groupby
dims = sorted(dims, key=lambda x: x.name)

# Group dimensions by matching names
groups = tuple(tuple(g) for n, g in groupby(dims, key=lambda x: x.name))
clashes = tuple(g for g in groups if len(g) > 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you only need generator no need for tuple


subs = {}
for c in clashes:
# Preferentially rename dims that aren't root dims
rdims = tuple(d for d in c if d.is_Root)
ddims = tuple(d for d in c if not d.is_Root)

for d in (ddims + rdims)[:-1]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is safe, a root dimension should never need to have its name change so this should only loop through ddims and raise an error if there is still clashes

if d in resolutions.keys():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and resolutions[d] not in clashes

Need to make sure you don't replace a clash by another but that case gets tricky because you can't set resoltuion[d] anymore or it would replace the existing one by a different one, not sure what the best way is for this very corner case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could always check that the incremented dimension name isn't in dims and incrementing until it is not?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also no need for .keys()

subs[d] = resolutions[d]
else:
try:
subs[d] = d._rebuild(d.name+str(count[d.name]))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

devito standard is formating '%s%s' % (d.name, count[d.name]) not string concatenation

count[d.name] += 1
except KeyError:
subs[d] = d._rebuild(d.name+'0')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same, '%s0' % d.name

count[d.name] = 1
resolutions[d] = subs[d]

processed.append(dxreplace(e, subs))

return processed
1 change: 0 additions & 1 deletion devito/ir/support/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ def detect_accesses(exprs):
for e in as_tuple(exprs):
other_dims.update(i for i in e.free_symbols if isinstance(i, Dimension))
other_dims.update(e.implicit_dims)
other_dims = filter_sorted(other_dims)
mapper[None] = Stencil([(i, 0) for i in other_dims])

return mapper
Expand Down
7 changes: 5 additions & 2 deletions devito/mpi/halo_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from devito import configuration
from devito.data import CORE, OWNED, LEFT, CENTER, RIGHT
from devito.ir.support import Forward, Scope
from devito.symbolics.manipulation import _uxreplace_registry
from devito.symbolics.manipulation import _uxreplace_registry, _dxreplace_registry
from devito.tools import (Reconstructable, Tag, as_tuple, filter_ordered, flatten,
frozendict, is_integer)
from devito.types import Grid
Expand Down Expand Up @@ -590,7 +590,7 @@ def __eq__(self, other):
func = Reconstructable._rebuild


def _uxreplace_dispatch_haloscheme(hs0, rule):
def _uxreplace_dispatch_haloscheme(hs0, rule, mode='ux'):
changed = False
hs = hs0
for f, hse0 in hs0.fmapper.items():
Expand Down Expand Up @@ -636,3 +636,6 @@ def _uxreplace_dispatch_haloscheme(hs0, rule):

_uxreplace_registry.register(HaloTouch,
{HaloScheme: _uxreplace_dispatch_haloscheme})

_dxreplace_registry.register(HaloTouch,
{HaloScheme: _uxreplace_dispatch_haloscheme})
5 changes: 4 additions & 1 deletion devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from devito.data import default_allocator
from devito.exceptions import InvalidOperator
from devito.logger import debug, info, perf, warning, is_log_enabled_for
from devito.ir.equations import LoweredEq, lower_exprs
from devito.ir.equations import LoweredEq, lower_exprs, separate_dimensions
from devito.ir.clusters import ClusterGroup, clusterize
from devito.ir.iet import (Callable, CInterface, EntryFunction, FindSymbols, MetaCall,
derive_parameters, iet_build)
Expand Down Expand Up @@ -328,6 +328,9 @@ def _lower_exprs(cls, expressions, **kwargs):
# "True" lowering (indexification, shifting, ...)
expressions = lower_exprs(expressions, **kwargs)

# Resolve clashing dimension names
expressions = separate_dimensions(expressions)

processed = [LoweredEq(i) for i in expressions]

return processed
Expand Down
96 changes: 77 additions & 19 deletions devito/symbolics/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@
from devito.symbolics.extended_sympy import DefFunction, rfunc
from devito.symbolics.queries import q_leaf
from devito.symbolics.search import retrieve_indexed, retrieve_functions
from devito.tools import as_list, as_tuple, flatten, split, transitive_closure
from devito.tools import (as_list, as_tuple, flatten, split, transitive_closure,
frozendict)
from devito.types.basic import Basic
from devito.types.array import ComponentAccess
from devito.types.equation import Eq
from devito.types.relational import Le, Lt, Gt, Ge
from devito.types.relational import Le, Lt, Gt, Ge, AbstractRel
from devito.types.dimension import ConditionalDimension

__all__ = ['xreplace_indices', 'pow_to_mul', 'indexify', 'subs_op_args',
'normalize_args', 'uxreplace', 'Uxmapper', 'reuse_if_untouched',
'evalrel']
'evalrel', 'dxreplace']


def uxreplace(expr, rule):
Expand All @@ -43,10 +45,19 @@ def uxreplace(expr, rule):
Finally, `uxreplace` supports Reconstructable objects, that is, it searches
for replacement opportunities inside the Reconstructable's `__rkwargs__`.
"""
return _uxreplace(expr, rule)[0]
return _uxreplace(expr, rule, mode='ux')[0]


def _uxreplace(expr, rule):
def dxreplace(expr, rule):
"""
As `uxreplace`, albeit with systematic replacement of dimensions in the
expression, including those contained within attached `ConditionalDimension`s
and `SubDomain`s, which would not be touched by the standard `uxreplace`.
"""
return _uxreplace(expr, rule, mode='dx')[0]


def _uxreplace(expr, rule, mode='ux'):
if expr in rule:
v = rule[expr]
if not isinstance(v, dict):
Expand All @@ -69,53 +80,85 @@ def _uxreplace(expr, rule):
changed = False

if rule:
eargs, flag = _uxreplace_dispatch(eargs, rule)
eargs, flag = _uxreplace_dispatch(eargs, rule, mode=mode)
args.extend(eargs)

changed |= flag

# If a Reconstructable object, we need to parse the kwargs as well
if _uxreplace_registry.dispatchable(expr):
v = {i: getattr(expr, i) for i in expr.__rkwargs__}
kwargs, flag = _uxreplace_dispatch(v, rule)
# Select the correct registry for the replacement mode specified
try:
registry = _replace_registries[mode]
except KeyError:
raise ValueError("Mode '%s' does not match any replacement mode"
% mode)

# If a Reconstructable object, we need to parse args and kwargs
if registry.dispatchable(expr):
if not args:
try:
v = [getattr(expr, i) for i in expr.__rargs__]
except AttributeError:
# Reconstructable has no required args
v = []
args, aflag = _uxreplace_dispatch(v, rule, mode=mode)
else:
aflag = False

try:
# Get the underlying attribute originally used to build
# the object rather than the property, as the latter can
# sometimes have default values which do not match the
# default for the class
v = {i: getattr(expr, "_"+i) for i in expr.__rkwargs__}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'_%s' % i

What if there isn't one and there is only the self.rkwargs one directly .function for any AbstractFunction

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not quite sure I understand?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some type have attribute such as self.function that are in rkwargs but self._function doesn't exist

except AttributeError:
# Reconstructable has no required kwargs
v = {}
kwargs, kwflag = _uxreplace_dispatch(v, rule, mode=mode)
else:
kwargs, flag = {}, False
aflag = False
kwargs, kwflag = {}, False
flag = aflag | kwflag
changed |= flag

if changed:
return _uxreplace_handle(expr, args, kwargs), True

return expr, False


@singledispatch
def _uxreplace_dispatch(unknown, rule):
def _uxreplace_dispatch(unknown, rule, mode='ux'):
return unknown, False


@_uxreplace_dispatch.register(Basic)
def _(expr, rule):
return _uxreplace(expr, rule)
def _(expr, rule, mode='ux'):
return _uxreplace(expr, rule, mode=mode)


@_uxreplace_dispatch.register(AbstractRel)
def _(expr, rule, mode='ux'):
return _uxreplace(expr, rule, mode=mode)


@_uxreplace_dispatch.register(tuple)
@_uxreplace_dispatch.register(Tuple)
@_uxreplace_dispatch.register(list)
def _(iterable, rule):
def _(iterable, rule, mode='ux'):
ret = []
changed = False
for a in iterable:
ax, flag = _uxreplace(a, rule)
ax, flag = _uxreplace(a, rule, mode=mode)
ret.append(ax)
changed |= flag
return iterable.__class__(ret), changed


@_uxreplace_dispatch.register(dict)
def _(mapper, rule):
def _(mapper, rule, mode='ux'):
ret = {}
changed = False
for k, v in mapper.items():
vx, flag = _uxreplace_dispatch(v, rule)
vx, flag = _uxreplace_dispatch(v, rule, mode=mode)
ret[k] = vx
changed |= flag
return ret, changed
Expand Down Expand Up @@ -180,13 +223,28 @@ def register(self, cls, rkwargs_callback_mapper=None):
_uxreplace_dispatch.register(kls, callback)

def dispatchable(self, obj):
# If not deep, ignore objects associated with deep uxreplace
return isinstance(obj, tuple(self))


# Registry for default uxreplace
_uxreplace_registry = UxreplaceRegistry()
_uxreplace_registry.register(Eq)
_uxreplace_registry.register(DefFunction)
_uxreplace_registry.register(ComponentAccess)
# Classes which only want uxreplacing when deep=True specified
# _uxreplace_registry.register(ConditionalDimension, deep=True)

# Registry for dimension uxreplace
_dxreplace_registry = UxreplaceRegistry()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename as ReplaceRegistry ?

_dxreplace_registry.register(Eq)
_dxreplace_registry.register(DefFunction)
_dxreplace_registry.register(ComponentAccess)
_dxreplace_registry.register(ConditionalDimension)

# Create a dict of registries
_replace_registries = frozendict({'ux': _uxreplace_registry,
'dx': _dxreplace_registry})


class Uxmapper(dict):
Expand Down
4 changes: 4 additions & 0 deletions devito/types/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ def is_const(self):
def root(self):
return self

@property
def is_Root(self):
return self == self.root
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is self.root


@cached_property
def bound_symbols(self):
candidates = [self.symbolic_min, self.symbolic_max, self.symbolic_size,
Expand Down
40 changes: 17 additions & 23 deletions devito/types/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ def __init__(self, shape, extent=None, origin=None, dimensions=None,

# Initialize SubDomains
subdomains = tuple(i for i in (Domain(), Interior(), *as_tuple(subdomains)))
for counter, i in enumerate(subdomains):
i.__subdomain_finalize__(self, counter=counter)
for i in subdomains:
i.__subdomain_finalize__(self)
self._subdomains = subdomains

self._origin = as_tuple(origin or tuple(0. for _ in self.shape))
Expand Down Expand Up @@ -486,7 +486,6 @@ def __subdomain_finalize__(self, grid, **kwargs):
# Create the SubDomain's SubDimensions
sub_dimensions = []
sdshape = []
counter = kwargs.get('counter', 0) - 1
for k, v, s in zip(self.define(grid.dimensions).keys(),
self.define(grid.dimensions).values(), grid.shape):
if isinstance(v, Dimension):
Expand All @@ -495,33 +494,28 @@ def __subdomain_finalize__(self, grid, **kwargs):
else:
try:
# Case ('middle', int, int)
side, thickness_left, thickness_right = v
side, tkn_left, tkn_right = v
if side != 'middle':
raise ValueError("Expected side 'middle', not `%s`" % side)
sub_dimensions.append(SubDimension.middle('i%d%s' %
(counter, k.name),
k, thickness_left,
thickness_right))
thickness = s-thickness_left-thickness_right
sdshape.append(thickness)
sub_dimensions.append(
SubDimension.middle(k.name, k, tkn_left, tkn_right)
)
tkn = s-tkn_left-tkn_right
sdshape.append(tkn)
except ValueError:
side, thickness = v
side, tkn = v
if side == 'left':
if s-thickness < 0:
if s-tkn < 0:
raise ValueError("Maximum thickness of dimension %s "
"is %d, not %d" % (k.name, s, thickness))
sub_dimensions.append(SubDimension.left('i%d%s' %
(counter, k.name),
k, thickness))
sdshape.append(thickness)
"is %d, not %d" % (k.name, s, tkn))
sub_dimensions.append(SubDimension.left(k.name, k, tkn))
sdshape.append(tkn)
elif side == 'right':
if s-thickness < 0:
if s-tkn < 0:
raise ValueError("Maximum thickness of dimension %s "
"is %d, not %d" % (k.name, s, thickness))
sub_dimensions.append(SubDimension.right('i%d%s' %
(counter, k.name),
k, thickness))
sdshape.append(thickness)
"is %d, not %d" % (k.name, s, tkn))
sub_dimensions.append(SubDimension.right(k.name, k, tkn))
sdshape.append(tkn)
else:
raise ValueError("Expected sides 'left|right', not `%s`" % side)

Expand Down
3 changes: 3 additions & 0 deletions devito/types/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ class AbstractRel(object):
"""
Abstract mixin class for objects subclassing sympy.Relational.
"""

__rargs__ = ('lhs', 'rhs')

@property
def negated(self):
return ops.get(self.func)(*self.args)
Expand Down
Loading