-
Notifications
You must be signed in to change notification settings - Fork 230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
compiler: Improve IndexDerivatives lowering #2183
Changes from all commits
83f892b
a8e1743
a7cba2d
354cf5c
379c571
f00228a
1189dcb
34daf46
d3c1d22
7a34949
09b57eb
7ee89ef
6a9e320
30243f2
dceaf1b
bccca77
87d0d2c
9d60dc8
dc9bac4
c9d3f96
aa79d6d
7d70ff9
7b039c6
998c9b3
d246cb2
e7ff141
8ecb7cc
a046fa3
756e0a1
3e7270b
4b7c213
11ff50f
44b752e
cc4d9ad
83d0924
0d18e85
c9f006a
f1466d1
6ed33ee
e10d99f
ee501d5
1827ad4
729bc3f
5242a1d
3f17791
da40050
6420df3
e4f186b
9041b7c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -207,9 +207,11 @@ def generic_derivative(expr, dim, fd_order, deriv_order, matvec=direct, x0=None, | |
matvec, x0, symbolic, expand) | ||
|
||
|
||
def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, symbolic, expand): | ||
def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, symbolic, | ||
expand): | ||
# The stencil indices | ||
indices, x0 = generate_indices(expr, dim, fd_order, side=side, matvec=matvec, x0=x0) | ||
indices, x0 = generate_indices(expr, dim, fd_order, side=side, matvec=matvec, | ||
x0=x0) | ||
|
||
# Finite difference weights from Taylor approximation given these positions | ||
if symbolic: | ||
|
@@ -221,15 +223,24 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, symbolic | |
weights = [sympify(w).evalf(_PRECISION) for w in weights] | ||
|
||
# Transpose the FD, if necessary | ||
if matvec: | ||
indices = indices.scale(matvec.val) | ||
indices = indices.scale(matvec.val) | ||
|
||
# Shift index due to staggering, if any | ||
indices = indices.shift(-(expr.indices_ref[dim] - dim)) | ||
|
||
# The user may wish to restrict expansion to selected derivatives | ||
if callable(expand): | ||
expand = expand(dim) | ||
|
||
if not expand and indices.expr is not None: | ||
weights = Weights(name='w', dimensions=indices.free_dim, initvalue=weights) | ||
|
||
if matvec == transpose: | ||
# For homogeneity, always generate e.g. `x + i0` for both transpose and | ||
# direct, rather than `x - i0` for transpose and `x + i0` for direct | ||
indices = indices.transpose() | ||
weights = weights._subs(indices.free_dim, -indices.free_dim) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It was that way, just like the indices above, but I had to change it because... I don't remember exactly |
||
|
||
# Inject the StencilDimension | ||
# E.g. `x + i*h_x` into `f(x)` s.t. `f(x + i*h_x)` | ||
expr = expr._subs(dim, indices.expr) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,7 @@ | |
Forward, Interval, IntervalGroup, IterationSpace, | ||
DataSpace, Guards, Properties, Scope, detect_accesses, | ||
detect_io, normalize_properties, normalize_syncs, | ||
sdims_min, sdims_max) | ||
minimum, maximum) | ||
from devito.mpi.halo_scheme import HaloScheme, HaloTouch | ||
from devito.symbolics import estimate_cost | ||
from devito.tools import as_tuple, flatten, frozendict, infer_dtype | ||
|
@@ -52,13 +52,7 @@ def __init__(self, exprs, ispace=None, guards=None, properties=None, syncs=None, | |
|
||
# Normalize properties | ||
properties = Properties(properties or {}) | ||
for d in ispace.itdimensions: | ||
properties = properties.add(d) | ||
for i in properties: | ||
for d in as_tuple(i): | ||
if d not in ispace.itdimensions: | ||
properties = properties.drop(d) | ||
self._properties = properties | ||
self._properties = tailor_properties(properties, ispace) | ||
|
||
self._halo_scheme = halo_scheme | ||
|
||
|
@@ -85,10 +79,7 @@ def from_clusters(cls, *clusters): | |
|
||
guards = root.guards | ||
|
||
properties = {} | ||
for c in clusters: | ||
for d, v in c.properties.items(): | ||
properties[d] = normalize_properties(properties.get(d, v), v) | ||
properties = reduce_properties(clusters) | ||
|
||
try: | ||
syncs = normalize_syncs(*[c.syncs for c in clusters]) | ||
|
@@ -213,12 +204,10 @@ def is_dense(self): | |
# at most PARALLEL_IF_PVT). This is a quick and easy check so we try it first | ||
try: | ||
pset = {PARALLEL, PARALLEL_IF_PVT} | ||
grid = self.grid | ||
for d in grid.dimensions: | ||
if not any(pset & v for k, v in self.properties.items() | ||
if d in k._defines): | ||
raise ValueError | ||
return True | ||
target = set(self.grid.dimensions) | ||
dims = {d for d in self.properties if d._defines & target} | ||
if any(pset & self.properties[d] for d in dims): | ||
return True | ||
except ValueError: | ||
pass | ||
|
||
|
@@ -276,8 +265,8 @@ def dspace(self): | |
continue | ||
|
||
intervals = [Interval(d, | ||
min([sdims_min(i) for i in offs]), | ||
max([sdims_max(i) for i in offs])) | ||
min([minimum(i) for i in offs]), | ||
max([maximum(i) for i in offs])) | ||
for d, offs in v.items()] | ||
intervals = IntervalGroup(intervals) | ||
|
||
|
@@ -418,15 +407,21 @@ def scope(self): | |
def ispace(self): | ||
return self._ispace | ||
|
||
@cached_property | ||
def properties(self): | ||
return tailor_properties(reduce_properties(self), self.ispace) | ||
|
||
@cached_property | ||
def guards(self): | ||
"""The guards of each Cluster in self.""" | ||
return tuple(i.guards for i in self) | ||
|
||
@cached_property | ||
def syncs(self): | ||
"""The synchronization operations of each Cluster in self.""" | ||
return tuple(i.syncs for i in self) | ||
""" | ||
A view of the ClusterGroup's synchronization operations. | ||
""" | ||
return normalize_syncs(*[c.syncs for c in self]) | ||
|
||
@cached_property | ||
def dspace(self): | ||
|
@@ -461,3 +456,26 @@ def meta(self): | |
The data type and the data space of the ClusterGroup. | ||
""" | ||
return (self.dtype, self.dspace) | ||
|
||
|
||
# *** Utils | ||
|
||
def reduce_properties(clusters): | ||
properties = {} | ||
for c in clusters: | ||
for d, v in c.properties.items(): | ||
properties[d] = normalize_properties(properties.get(d, v), v) | ||
|
||
return Properties(properties) | ||
|
||
|
||
def tailor_properties(properties, ispace): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. tailor_*: it would be better to add docstrings that make clear what each of these functions is used for. |
||
for d in ispace.itdimensions: | ||
properties = properties.add(d) | ||
|
||
for i in properties: | ||
for d in as_tuple(i): | ||
if d not in ispace.itdimensions: | ||
properties = properties.drop(d) | ||
|
||
return properties |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Homogeneity