-
Notifications
You must be signed in to change notification settings - Fork 234
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into open-closed-iterations
- Loading branch information
Showing
19 changed files
with
1,533 additions
and
882 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
""" | ||
Extended SymPy hierarchy. | ||
""" | ||
|
||
import sympy | ||
from sympy import Expr, Float | ||
from sympy.core.basic import _aresame | ||
from sympy.functions.elementary.trigonometric import TrigonometricFunction | ||
|
||
|
||
class UnevaluatedExpr(Expr): | ||
|
||
""" | ||
Use :class:`UnevaluatedExpr` in place of :class:`sympy.Expr` to prevent | ||
xreplace from unpicking factorizations. | ||
""" | ||
|
||
def xreplace(self, rule): | ||
if self in rule: | ||
return rule[self] | ||
elif rule: | ||
args = [] | ||
for a in self.args: | ||
try: | ||
args.append(a.xreplace(rule)) | ||
except AttributeError: | ||
args.append(a) | ||
args = tuple(args) | ||
if not _aresame(args, self.args): | ||
return self.func(*args, evaluate=False) | ||
return self | ||
|
||
|
||
class Mul(sympy.Mul, UnevaluatedExpr): | ||
pass | ||
|
||
|
||
class Add(sympy.Add, UnevaluatedExpr): | ||
pass | ||
|
||
|
||
class taylor_sin(TrigonometricFunction): | ||
|
||
""" | ||
Approximation of the sine function using a Taylor polynomial. | ||
""" | ||
|
||
@classmethod | ||
def eval(cls, arg): | ||
return eval_taylor_sin(arg) | ||
|
||
|
||
class taylor_cos(TrigonometricFunction): | ||
|
||
""" | ||
Approximation of the cosine function using a Taylor polynomial. | ||
""" | ||
|
||
@classmethod | ||
def eval(cls, arg): | ||
return 1.0 if arg == 0.0 else eval_taylor_cos(arg + 1.5708) | ||
|
||
|
||
class bhaskara_sin(TrigonometricFunction): | ||
|
||
""" | ||
Approximation of the sine function using a Bhaskara polynomial. | ||
""" | ||
|
||
@classmethod | ||
def eval(cls, arg): | ||
return eval_bhaskara_sin(arg) | ||
|
||
|
||
class bhaskara_cos(TrigonometricFunction): | ||
|
||
""" | ||
Approximation of the cosine function using a Bhaskara polynomial. | ||
""" | ||
|
||
@classmethod | ||
def eval(cls, arg): | ||
return 1.0 if arg == 0.0 else eval_bhaskara_sin(arg + 1.5708) | ||
|
||
|
||
# Utils | ||
|
||
def eval_bhaskara_sin(expr): | ||
return 16.0*expr*(3.1416-abs(expr))/(49.3483-4.0*abs(expr)*(3.1416-abs(expr))) | ||
|
||
|
||
def eval_taylor_sin(expr): | ||
v = expr + Mul(-1/6.0, | ||
Mul(expr, expr, expr, evaluate=False), | ||
1.0 + Mul(Mul(expr, expr, evaluate=False), -0.05, evaluate=False), | ||
evaluate=False) | ||
try: | ||
Float(expr) | ||
return v.doit() | ||
except (TypeError, ValueError): | ||
return v | ||
|
||
|
||
def eval_taylor_cos(expr): | ||
v = 1.0 + Mul(-0.5, | ||
Mul(expr, expr, evaluate=False), | ||
1.0 + Mul(expr, expr, -1/12.0, evaluate=False), | ||
evaluate=False) | ||
try: | ||
Float(expr) | ||
return v.doit() | ||
except (TypeError, ValueError): | ||
return v |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
""" | ||
In a DSE graph, a node is a temporary and an edge between two nodes n0 and n1 | ||
indicates that n1 reads n0. For example, given the excerpt: :: | ||
temp0 = a*b | ||
temp1 = temp0*c | ||
temp2 = temp0*d | ||
temp3 = temp1 + temp2 | ||
... | ||
A section of the ``temporaries graph`` looks as follows: :: | ||
temp0 ---> temp1 | ||
| | | ||
| | | ||
v v | ||
temp2 ---> temp3 | ||
Temporaries graph are used for symbolic as well as loop-level transformations. | ||
""" | ||
|
||
from collections import OrderedDict, namedtuple | ||
|
||
from sympy import (Eq, Indexed) | ||
|
||
from devito.dse.inspection import is_time_invariant, terminals | ||
from devito.dimension import t | ||
|
||
__all__ = ['temporaries_graph'] | ||
|
||
|
||
class Temporary(Eq): | ||
|
||
""" | ||
A special :class:`sympy.Eq` which keeps track of: :: | ||
- :class:`sympy.Eq` writing to ``self`` | ||
- :class:`sympy.Eq` reading from ``self`` | ||
A :class:`Temporary` is used as node in a temporaries graph. | ||
""" | ||
|
||
def __new__(cls, lhs, rhs, **kwargs): | ||
reads = kwargs.pop('reads', []) | ||
readby = kwargs.pop('readby', []) | ||
time_invariant = kwargs.pop('time_invariant', False) | ||
scope = kwargs.pop('scope', 0) | ||
obj = super(Temporary, cls).__new__(cls, lhs, rhs, **kwargs) | ||
obj._reads = set(reads) | ||
obj._readby = set(readby) | ||
obj._is_time_invariant = time_invariant | ||
obj._scope = scope | ||
return obj | ||
|
||
@property | ||
def reads(self): | ||
return self._reads | ||
|
||
@property | ||
def readby(self): | ||
return self._readby | ||
|
||
@property | ||
def is_time_invariant(self): | ||
return self._is_time_invariant | ||
|
||
@property | ||
def is_terminal(self): | ||
return len(self.readby) == 0 | ||
|
||
@property | ||
def is_tensor(self): | ||
return isinstance(self.lhs, Indexed) and self.lhs.rank > 0 | ||
|
||
@property | ||
def is_scalarizable(self): | ||
return not self.is_terminal and self.is_tensor | ||
|
||
@property | ||
def scope(self): | ||
return self._scope | ||
|
||
def construct(self, rule): | ||
""" | ||
Create a new temporary starting from ``self`` replacing symbols in | ||
the equation as specified by the dictionary ``rule``. | ||
""" | ||
reads = set(self.reads) - set(rule.keys()) | set(rule.values()) | ||
rhs = self.rhs.xreplace(rule) | ||
return Temporary(self.lhs, rhs, reads=reads, readby=self.readby, | ||
time_invariant=self.is_time_invariant, scope=self.scope) | ||
|
||
def __repr__(self): | ||
return "DSE(%s, reads=%s, readby=%s)" % (super(Temporary, self).__repr__(), | ||
str(self.reads), str(self.readby)) | ||
|
||
|
||
class TemporariesGraph(OrderedDict): | ||
|
||
""" | ||
A temporaries graph built on top of an OrderedDict. | ||
""" | ||
|
||
def space_dimensions(self): | ||
for v in self.values(): | ||
if v.is_terminal: | ||
found = v.lhs.free_symbols - {t, v.lhs.base.label} | ||
return tuple(sorted(found, key=lambda i: v.lhs.indices.index(i))) | ||
return () | ||
|
||
|
||
class Trace(OrderedDict): | ||
|
||
""" | ||
Assign a depth level to each temporary in a temporary graph. | ||
""" | ||
|
||
def __init__(self, root, graph, *args, **kwargs): | ||
super(Trace, self).__init__(*args, **kwargs) | ||
self._root = root | ||
self._compute(graph) | ||
|
||
def _compute(self, graph): | ||
if self.root not in graph: | ||
return | ||
to_visit = [(graph[self.root], 0)] | ||
while to_visit: | ||
temporary, level = to_visit.pop(0) | ||
self.__setitem__(temporary.lhs, level) | ||
to_visit.extend([(graph[i], level + 1) for i in temporary.reads]) | ||
|
||
@property | ||
def root(self): | ||
return self._root | ||
|
||
@property | ||
def length(self): | ||
return len(self) | ||
|
||
def intersect(self, other): | ||
return Trace(self.root, {}, [(k, v) for k, v in self.items() if k in other]) | ||
|
||
def union(self, other): | ||
return Trace(self.root, {}, [(k, v) for k, v in self.items() + other.items()]) | ||
|
||
|
||
def temporaries_graph(temporaries, scope=0): | ||
""" | ||
Create a temporaries graph given a list of :class:`sympy.Eq`. | ||
""" | ||
|
||
mapper = OrderedDict() | ||
Node = namedtuple('Node', ['rhs', 'reads', 'readby', 'time_invariant']) | ||
|
||
for lhs, rhs in [i.args for i in temporaries]: | ||
reads = {i for i in terminals(rhs) if i in mapper} | ||
mapper[lhs] = Node(rhs, reads, set(), is_time_invariant(rhs, mapper)) | ||
for i in mapper[lhs].reads: | ||
assert i in mapper, "Illegal Flow" | ||
mapper[i].readby.add(lhs) | ||
|
||
nodes = [Temporary(k, v.rhs, reads=v.reads, readby=v.readby, | ||
time_invariant=v.time_invariant, scope=scope) | ||
for k, v in mapper.items()] | ||
|
||
return TemporariesGraph([(i.lhs, i) for i in nodes]) |
Oops, something went wrong.