diff --git a/devito/compiler.py b/devito/compiler.py index f846f6703b..e74ad81e2f 100644 --- a/devito/compiler.py +++ b/devito/compiler.py @@ -1,12 +1,12 @@ from functools import partial from os import environ, getuid, mkdir, path from tempfile import gettempdir +from time import time import numpy.ctypeslib as npct from cgen import Pragma from codepy.jit import extension_file_from_string from codepy.toolchain import GCCToolchain - from devito.logger import log __all__ = ['get_tmp_dir', 'get_compiler_from_env', @@ -243,9 +243,11 @@ def jit_compile(ccode, basename, compiler=GNUCompiler): """ src_file = "%s.%s" % (basename, compiler.src_ext) lib_file = "%s.%s" % (basename, compiler.lib_ext) - log("%s: Compiling %s" % (compiler, src_file)) + tic = time() extension_file_from_string(toolchain=compiler, ext_file=lib_file, source_string=ccode, source_name=src_file) + toc = time() + log("%s: compiled %s [%.2f s]" % (compiler, src_file, toc-tic)) return lib_file diff --git a/devito/dse/__init__.py b/devito/dse/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/devito/dse/extended_sympy.py b/devito/dse/extended_sympy.py new file mode 100644 index 0000000000..597bb41197 --- /dev/null +++ b/devito/dse/extended_sympy.py @@ -0,0 +1,113 @@ +""" +Extended SymPy hierarchy. +""" + +import sympy +from sympy import Expr, Float +from sympy.core.basic import _aresame +from sympy.functions.elementary.trigonometric import TrigonometricFunction + + +class UnevaluatedExpr(Expr): + + """ + Use :class:`UnevaluatedExpr` in place of :class:`sympy.Expr` to prevent + xreplace from unpicking factorizations. + """ + + def xreplace(self, rule): + if self in rule: + return rule[self] + elif rule: + args = [] + for a in self.args: + try: + args.append(a.xreplace(rule)) + except AttributeError: + args.append(a) + args = tuple(args) + if not _aresame(args, self.args): + return self.func(*args, evaluate=False) + return self + + +class Mul(sympy.Mul, UnevaluatedExpr): + pass + + +class Add(sympy.Add, UnevaluatedExpr): + pass + + +class taylor_sin(TrigonometricFunction): + + """ + Approximation of the sine function using a Taylor polynomial. + """ + + @classmethod + def eval(cls, arg): + return eval_taylor_sin(arg) + + +class taylor_cos(TrigonometricFunction): + + """ + Approximation of the cosine function using a Taylor polynomial. + """ + + @classmethod + def eval(cls, arg): + return 1.0 if arg == 0.0 else eval_taylor_cos(arg + 1.5708) + + +class bhaskara_sin(TrigonometricFunction): + + """ + Approximation of the sine function using a Bhaskara polynomial. + """ + + @classmethod + def eval(cls, arg): + return eval_bhaskara_sin(arg) + + +class bhaskara_cos(TrigonometricFunction): + + """ + Approximation of the cosine function using a Bhaskara polynomial. + """ + + @classmethod + def eval(cls, arg): + return 1.0 if arg == 0.0 else eval_bhaskara_sin(arg + 1.5708) + + +# Utils + +def eval_bhaskara_sin(expr): + return 16.0*expr*(3.1416-abs(expr))/(49.3483-4.0*abs(expr)*(3.1416-abs(expr))) + + +def eval_taylor_sin(expr): + v = expr + Mul(-1/6.0, + Mul(expr, expr, expr, evaluate=False), + 1.0 + Mul(Mul(expr, expr, evaluate=False), -0.05, evaluate=False), + evaluate=False) + try: + Float(expr) + return v.doit() + except (TypeError, ValueError): + return v + + +def eval_taylor_cos(expr): + v = 1.0 + Mul(-0.5, + Mul(expr, expr, evaluate=False), + 1.0 + Mul(expr, expr, -1/12.0, evaluate=False), + evaluate=False) + try: + Float(expr) + return v.doit() + except (TypeError, ValueError): + return v diff --git a/devito/dse/graph.py b/devito/dse/graph.py new file mode 100644 index 0000000000..3a3aa59e00 --- /dev/null +++ b/devito/dse/graph.py @@ -0,0 +1,166 @@ +""" +In a DSE graph, a node is a temporary and an edge between two nodes n0 and n1 +indicates that n1 reads n0. For example, given the excerpt: :: + + temp0 = a*b + temp1 = temp0*c + temp2 = temp0*d + temp3 = temp1 + temp2 + ... + +A section of the ``temporaries graph`` looks as follows: :: + + temp0 ---> temp1 + | | + | | + v v + temp2 ---> temp3 + +Temporaries graph are used for symbolic as well as loop-level transformations. +""" + +from collections import OrderedDict, namedtuple + +from sympy import (Eq, Indexed) + +from devito.dse.inspection import is_time_invariant, terminals +from devito.dimension import t + +__all__ = ['temporaries_graph'] + + +class Temporary(Eq): + + """ + A special :class:`sympy.Eq` which keeps track of: :: + + - :class:`sympy.Eq` writing to ``self`` + - :class:`sympy.Eq` reading from ``self`` + + A :class:`Temporary` is used as node in a temporaries graph. + """ + + def __new__(cls, lhs, rhs, **kwargs): + reads = kwargs.pop('reads', []) + readby = kwargs.pop('readby', []) + time_invariant = kwargs.pop('time_invariant', False) + scope = kwargs.pop('scope', 0) + obj = super(Temporary, cls).__new__(cls, lhs, rhs, **kwargs) + obj._reads = set(reads) + obj._readby = set(readby) + obj._is_time_invariant = time_invariant + obj._scope = scope + return obj + + @property + def reads(self): + return self._reads + + @property + def readby(self): + return self._readby + + @property + def is_time_invariant(self): + return self._is_time_invariant + + @property + def is_terminal(self): + return len(self.readby) == 0 + + @property + def is_tensor(self): + return isinstance(self.lhs, Indexed) and self.lhs.rank > 0 + + @property + def is_scalarizable(self): + return not self.is_terminal and self.is_tensor + + @property + def scope(self): + return self._scope + + def construct(self, rule): + """ + Create a new temporary starting from ``self`` replacing symbols in + the equation as specified by the dictionary ``rule``. + """ + reads = set(self.reads) - set(rule.keys()) | set(rule.values()) + rhs = self.rhs.xreplace(rule) + return Temporary(self.lhs, rhs, reads=reads, readby=self.readby, + time_invariant=self.is_time_invariant, scope=self.scope) + + def __repr__(self): + return "DSE(%s, reads=%s, readby=%s)" % (super(Temporary, self).__repr__(), + str(self.reads), str(self.readby)) + + +class TemporariesGraph(OrderedDict): + + """ + A temporaries graph built on top of an OrderedDict. + """ + + def space_dimensions(self): + for v in self.values(): + if v.is_terminal: + found = v.lhs.free_symbols - {t, v.lhs.base.label} + return tuple(sorted(found, key=lambda i: v.lhs.indices.index(i))) + return () + + +class Trace(OrderedDict): + + """ + Assign a depth level to each temporary in a temporary graph. + """ + + def __init__(self, root, graph, *args, **kwargs): + super(Trace, self).__init__(*args, **kwargs) + self._root = root + self._compute(graph) + + def _compute(self, graph): + if self.root not in graph: + return + to_visit = [(graph[self.root], 0)] + while to_visit: + temporary, level = to_visit.pop(0) + self.__setitem__(temporary.lhs, level) + to_visit.extend([(graph[i], level + 1) for i in temporary.reads]) + + @property + def root(self): + return self._root + + @property + def length(self): + return len(self) + + def intersect(self, other): + return Trace(self.root, {}, [(k, v) for k, v in self.items() if k in other]) + + def union(self, other): + return Trace(self.root, {}, [(k, v) for k, v in self.items() + other.items()]) + + +def temporaries_graph(temporaries, scope=0): + """ + Create a temporaries graph given a list of :class:`sympy.Eq`. + """ + + mapper = OrderedDict() + Node = namedtuple('Node', ['rhs', 'reads', 'readby', 'time_invariant']) + + for lhs, rhs in [i.args for i in temporaries]: + reads = {i for i in terminals(rhs) if i in mapper} + mapper[lhs] = Node(rhs, reads, set(), is_time_invariant(rhs, mapper)) + for i in mapper[lhs].reads: + assert i in mapper, "Illegal Flow" + mapper[i].readby.add(lhs) + + nodes = [Temporary(k, v.rhs, reads=v.reads, readby=v.readby, + time_invariant=v.time_invariant, scope=scope) + for k, v in mapper.items()] + + return TemporariesGraph([(i.lhs, i) for i in nodes]) diff --git a/devito/dse/inspection.py b/devito/dse/inspection.py new file mode 100644 index 0000000000..aaecc032f0 --- /dev/null +++ b/devito/dse/inspection.py @@ -0,0 +1,377 @@ +import numpy as np +from sympy import (Indexed, Function, Number, Symbol, + count_ops, lambdify, preorder_traversal, sin, cos) + +from devito.dimension import t +from devito.interfaces import SymbolicData +from devito.logger import warning +from devito.tools import flatten + +__all__ = ['indexify', 'retrieve_dimensions', 'retrieve_dtype', 'retrieve_symbols', + 'retrieve_shape', 'terminals', 'tolambda'] + + +def terminals(expr, discard_indexed=False): + """ + Return all Indexed and Symbols in a SymPy expression. + """ + + indexed = retrieve_indexed(expr) + + # Use '.name' for quickly checking uniqueness + junk = flatten([i.free_symbols for i in indexed]) + junk = [i.name for i in junk] + + symbols = {i for i in expr.free_symbols if i.name not in junk} + + if discard_indexed: + return symbols + else: + indexed.update(symbols) + return indexed + + +def collect_aliases(exprs): + """ + Determine all expressions in ``exprs`` that alias to the same expression. + + An expression A aliases an expression B if both A and B apply the same + operations to the same input operands, with the possibility for + :class:`Indexed` to index into locations at a fixed constant offset in + each dimension. + + For example: :: + + exprs = (a[i+1] + b[i+1], a[i+1] + b[j+1], a[i] + c[i], + a[i+2] - b[i+2], a[i+2] + b[i], a[i-1] + b[i-1]) + + The following expressions in ``exprs`` alias to ``a[i] + b[i]``: :: + + ``(a[i+1] + b[i+1], a[i-1] + b[i-1])`` + + Whereas the following do not: :: + + ``a[i+1] + b[j+1]``: because at least one index differs + ``a[i] + c[i]``: because at least one of the operands differs + ``a[i+2] - b[i+2]``: because at least one operation differs + ``a[i+2] + b[i]``: because there are two offsets (+2 and +0) + """ + + def check_ofs(e): + return len(set([i.indices for i in retrieve_indexed(e)])) <= 1 + + def compare_ops(e1, e2): + if type(e1) == type(e2) and len(e1.args) == len(e2.args): + if e1.is_Atom: + return True if e1 == e2 else False + elif isinstance(e1, Indexed) and isinstance(e2, Indexed): + return True if e1.base == e2.base else False + else: + for a1, a2 in zip(e1.args, e2.args): + if not compare_ops(a1, a2): + return False + return True + else: + return False + + def compare(e1, e2): + return compare_ops(e1, e2) and check_ofs(e1) and check_ofs(e2) + + found = {} + clusters = [] + unseen = list(exprs) + while unseen: + handle = unseen[0] + alias = [] + for e in list(unseen): + if compare(handle, e): + alias.append(e) + unseen.remove(e) + if alias: + cluster = tuple(alias) + for e in alias: + found[e] = cluster + clusters.append(cluster) + else: + unseen.remove(handle) + found[handle] = () + + return found, clusters + + +def estimate_cost(handle, estimate_external_functions=False): + """Estimate the operation count of ``handle``. + + :param handle: a SymPy expression or an iterator of SymPy expressions. + :param estimate_external_functions: approximate the operation count of known + functions (eg, sin, cos). + """ + internal_ops = {'trigonometry': 50} + try: + # Is it a plain SymPy object ? + iter(handle) + except TypeError: + handle = [handle] + try: + # Is it a dict ? + handle = handle.values() + except AttributeError: + try: + # Must be a list of dicts then + handle = flatten([i.values() for i in handle]) + except AttributeError: + pass + try: + # At this point it must be a list of SymPy objects + # We don't count non floating point operations + handle = [i.rhs if i.is_Equality else i for i in handle] + total_ops = count_ops(handle) + non_flops = sum(count_ops(retrieve_indexed(i, mode='all')) for i in handle) + if estimate_external_functions: + costly_ops = [retrieve_trigonometry(i) for i in handle] + total_ops += sum([internal_ops['trigonometry']*len(i) for i in costly_ops]) + return total_ops - non_flops + except: + warning("Cannot estimate cost of %s" % str(handle)) + + +def estimate_memory(handle, mode='realistic'): + """Estimate the number of memory reads and writes. + + :param handle: a SymPy expression or an iterator of SymPy expressions. + :param mode: There are multiple ways of computing the estimate: :: + + * ideal: also known as "compulsory traffic", which is the minimum + number of read/writes to be performed (ie, models an infinite cache). + * ideal_with_stores: like ideal, but a data item which is both read + and written is counted twice (ie both load and store are counted). + * realistic: assume that all datasets, even the time-independent ones, + need to be re-read at each time iteration. + """ + assert mode in ['ideal', 'ideal_with_stores', 'realistic'] + + def access(symbol): + assert isinstance(symbol, Indexed) + # Irregular accesses (eg A[B[i]]) are counted as compulsory traffic + if any(i.atoms(Indexed) for i in symbol.indices): + return symbol + else: + return symbol.base + + try: + # Is it a plain SymPy object ? + iter(handle) + except TypeError: + handle = [handle] + + if mode in ['ideal', 'ideal_with_stores']: + filter = lambda s: t in s.atoms() + else: + filter = lambda s: s + reads = set(flatten([retrieve_indexed(e.rhs) for e in handle])) + writes = set(flatten([retrieve_indexed(e.lhs) for e in handle])) + reads = set([access(s) for s in reads if filter(s)]) + writes = set([access(s) for s in writes if filter(s)]) + if mode == 'ideal': + return len(set(reads) | set(writes)) + else: + return len(reads) + len(writes) + + +def retrieve_dimensions(expr): + """ + Collect all function dimensions used in a sympy expression. + """ + dimensions = [] + + for e in preorder_traversal(expr): + if isinstance(e, SymbolicData): + dimensions += [i for i in e.indices if i not in dimensions] + + return dimensions + + +def retrieve_symbols(expr): + """ + Collect defined and undefined symbols used in a sympy expression. + + Defined symbols are functions that have an associated :class + SymbolicData: object, or dimensions that are known to the devito + engine. Undefined symbols are generic `sympy.Function` or + `sympy.Symbol` objects that need to be substituted before generating + operator C code. + """ + defined = set() + undefined = set() + + for e in preorder_traversal(expr): + if isinstance(e, SymbolicData): + defined.add(e.func(*e.indices)) + elif isinstance(e, Function): + undefined.add(e) + elif isinstance(e, Symbol): + undefined.add(e) + + return list(defined), list(undefined) + + +def retrieve_dtype(expr): + """ + Try to infer the data type of an expression. + """ + dtypes = [e.dtype for e in preorder_traversal(expr) if hasattr(e, 'dtype')] + return np.find_common_type(dtypes, []) + + +def retrieve_shape(expr): + indexed = set([e for e in preorder_traversal(expr) if isinstance(e, Indexed)]) + if not indexed: + return () + indexed = sorted(indexed, key=lambda s: len(s.indices), reverse=True) + indices = [flatten([j.free_symbols for j in i.indices]) for i in indexed] + assert all(set(indices[0]).issuperset(set(i)) for i in indices) + return tuple(indices[0]) + + +def retrieve(expr, query, mode): + """ + Find objects in an expression. This is much quicker than the more general + SymPy's find. + + :param expr: The searched expression + :param query: Search query (accepted: 'indexed', 'trigonometry') + :param mode: either 'unique' or 'all' (catch all instances) + """ + + class Set(set): + + @staticmethod + def wrap(obj): + return {obj} + + class List(list): + + @staticmethod + def wrap(obj): + return [obj] + + def update(self, obj): + return self.extend(obj) + + rules = { + 'indexed': lambda e: isinstance(e, Indexed), + 'trigonometry': lambda e: e.is_Function and e.func in [sin, cos] + } + modes = { + 'unique': Set, + 'all': List + } + assert mode in modes + collection = modes[mode] + assert query in rules, "Unknown query" + rule = rules[query] + + def run(expr): + if rule(expr): + return collection.wrap(expr) + else: + found = collection() + for a in expr.args: + found.update(run(a)) + return found + + return run(expr) + + +def retrieve_indexed(expr, mode='unique'): + """ + Shorthand for ``retrieve(expr, 'indexed', 'unique')``. + """ + return retrieve(expr, 'indexed', mode) + + +def retrieve_trigonometry(expr, mode='unique'): + """ + Shorthand for ``retrieve(expr, 'trigonometry', 'unique')``. + """ + return retrieve(expr, 'trigonometry', mode) + + +def is_time_invariant(expr, graph=None): + """ + Check if expr is time invariant. A temporaries graph may be provided + to determine whether any of the symbols involved in the evaluation + of expr are time-dependent. If a symbol in expr does not appear in the + graph, then time invariance is inferred from its shape. + """ + graph = graph or {} + + if t in expr.free_symbols: + return False + elif expr in graph: + return graph[expr].is_time_invariant + + if expr.is_Equality: + to_visit = [expr.rhs] + else: + to_visit = [expr] + + while to_visit: + handle = to_visit.pop() + for i in retrieve_indexed(handle): + if t in i.free_symbols: + return False + temporaries = [i for i in handle.free_symbols if i in graph] + for i in temporaries: + to_visit.append(graph[i].rhs) + + return True + + +def is_binary_op(expr): + """ + Return True if ``expr`` is a binary operation, False otherwise. + """ + + if not (expr.is_Add or expr.is_Mul) and not len(expr.args) == 2: + return False + + return all(isinstance(a, (Number, Symbol, Indexed)) for a in expr.args) + + +def indexify(expr): + """ + Convert functions into indexed matrix accesses in sympy expression. + + :param expr: sympy function expression to be converted. + """ + replacements = {} + + for e in preorder_traversal(expr): + if hasattr(e, 'indexed'): + replacements[e] = e.indexify() + + return expr.xreplace(replacements) + + +def tolambda(exprs): + """ + Tranform an expression into a lambda. + + :param exprs: an expression or a list of expressions. + """ + exprs = exprs if isinstance(exprs, list) else [exprs] + + lambdas = [] + + for expr in exprs: + terms = retrieve_indexed(expr.rhs) + term_symbols = [Symbol("i%d" % i) for i in range(len(terms))] + + # Substitute IndexedBase references to simple variables + # lambdify doesn't support IndexedBase references in expressions + tolambdify = expr.rhs.subs(dict(zip(terms, term_symbols))) + lambdified = lambdify(term_symbols, tolambdify) + lambdas.append((lambdified, terms)) + + return lambdas diff --git a/devito/dse/manipulation.py b/devito/dse/manipulation.py new file mode 100644 index 0000000000..553ba5288b --- /dev/null +++ b/devito/dse/manipulation.py @@ -0,0 +1,76 @@ +""" +Routines to construct new SymPy expressions transforming the provided input. +""" + +from sympy import Indexed, S + +from devito.dse.extended_sympy import Add, Mul + + +def unevaluate_arithmetic(expr): + """ + Reconstruct ``expr`` turning all :class:`sympy.Mul` and :class:`sympy.Add` + into, respectively, :class:`devito.Mul` and :class:`devito.Add`. + """ + if expr.is_Float: + return expr.func(*expr.atoms()) + elif isinstance(expr, Indexed): + return expr.func(*expr.args) + elif expr.is_Symbol: + return expr.func(expr.name) + elif expr in [S.Zero, S.One, S.NegativeOne, S.Half]: + return expr.func() + elif expr.is_Atom: + return expr.func(*expr.atoms()) + elif expr.is_Add: + rebuilt_args = [unevaluate_arithmetic(e) for e in expr.args] + return Add(*rebuilt_args, evaluate=False) + elif expr.is_Mul: + rebuilt_args = [unevaluate_arithmetic(e) for e in expr.args] + return Mul(*rebuilt_args, evaluate=False) + else: + return expr.func(*[unevaluate_arithmetic(e) for e in expr.args]) + + +def flip_indices(expr, rule): + """ + Construct a new ``expr'`` from ``expr`` such that all indices are shifted as + established by ``rule``. + + For example: :: + + (rule=(x, y)) a[i][j+2] + b[j][i] --> a[x][y] + b[x][y] + """ + + def run(expr, flipped): + if expr.is_Float: + return expr.func(*expr.atoms()) + elif isinstance(expr, Indexed): + flipped.add(expr.indices) + return Indexed(expr.base, *rule) + elif expr.is_Symbol: + return expr.func(expr.name) + elif expr in [S.Zero, S.One, S.NegativeOne, S.Half]: + return expr.func() + elif expr.is_Atom: + return expr.func(*expr.atoms()) + else: + return expr.func(*[run(e, flipped) for e in expr.args], evaluate=False) + + flipped = set() + handle = run(expr, flipped) + return handle, flipped + + +def rxreplace(exprs, mapper): + """ + Apply Sympy's xreplace recursively. + """ + + replaced = [] + for i in exprs: + old, new = i, i.xreplace(mapper) + while new != old: + old, new = new, new.xreplace(mapper) + replaced.append(new) + return replaced diff --git a/devito/dse/symbolics.py b/devito/dse/symbolics.py new file mode 100644 index 0000000000..5fa0648491 --- /dev/null +++ b/devito/dse/symbolics.py @@ -0,0 +1,503 @@ +""" +The Devito symbolic engine is built on top of SymPy and provides two +classes of functions: +- for inspection of expressions +- for (in-place) manipulation of expressions +- for creation of new objects given some expressions +All exposed functions are prefixed with 'dse' (devito symbolic engine) +""" + +from __future__ import absolute_import + +from collections import OrderedDict, Sequence +from time import time + +from sympy import (Eq, Indexed, IndexedBase, S, + collect, collect_const, cos, cse, flatten, + numbered_symbols, preorder_traversal, sin) + +from devito.dimension import t, x, y, z +from devito.logger import dse, dse_warning + +from devito.dse.extended_sympy import bhaskara_sin, bhaskara_cos +from devito.dse.graph import temporaries_graph +from devito.dse.inspection import (collect_aliases, estimate_cost, estimate_memory, + is_binary_op, is_time_invariant, terminals) +from devito.dse.manipulation import flip_indices, rxreplace, unevaluate_arithmetic + +__all__ = ['rewrite'] + +_temp_prefix = 'temp' + + +def rewrite(expr, mode='advanced'): + """ + Transform expressions to reduce their operation count. + + :param expr: the target expression. + :param mode: drive the expression transformation. Available modes are + 'basic', 'factorize', 'approx-trigonometry' and 'advanced' + (default). They act as follows: :: + + * 'noop': do nothing, but track performance metrics + * 'basic': apply common sub-expressions elimination. + * 'factorize': apply heuristic factorization of temporaries. + * 'approx-trigonometry': replace expensive trigonometric + functions with suitable polynomial approximations. + * 'glicm': apply heuristic hoisting of time-invariant terms. + * 'advanced': compose all known transformations. + """ + + if isinstance(expr, Sequence): + assert all(isinstance(e, Eq) for e in expr) + expr = list(expr) + elif isinstance(expr, Eq): + expr = [expr] + else: + raise ValueError("Got illegal expr of type %s." % type(expr)) + + if not mode: + return State(expr) + elif isinstance(mode, str): + mode = set([mode]) + else: + try: + mode = set(mode) + except TypeError: + dse_warning("Arg mode must be str or tuple (got %s)" % type(mode)) + return expr + if mode.isdisjoint({'noop', 'basic', 'factorize', 'approx-trigonometry', + 'glicm', 'advanced'}): + dse_warning("Unknown rewrite mode(s) %s" % str(mode)) + return State(expr) + else: + return Rewriter(expr).run(mode) + + +def dse_transformation(func): + + def wrapper(self, state, **kwargs): + if kwargs['mode'].intersection(set(self.triggers[func.__name__])): + tic = time() + state.update(**func(self, state)) + toc = time() + + key = '%s%d' % (func.__name__, len(self.timings)) + self.ops[key] = estimate_cost(state.exprs) + self.timings[key] = toc - tic + + return wrapper + + +class State(object): + + def __init__(self, exprs): + self.exprs = exprs + self.mapper = OrderedDict() + + def update(self, exprs=None, mapper=None): + self.exprs = exprs or self.exprs + self.mapper = mapper or self.mapper + + @property + def time_invariants(self): + return [i for i in self.exprs if i.lhs in self.mapper] + + @property + def time_varying(self): + return [i for i in self.exprs if i not in self.time_invariants] + + @property + def ops_time_invariants(self): + return estimate_cost(self.time_invariants) + + @property + def ops_time_varying(self): + return estimate_cost(self.time_varying) + + @property + def ops(self): + return self.ops_time_invariants + self.ops_time_varying + + @property + def memory_time_invariants(self): + return estimate_memory(self.time_invariants) + + @property + def memory_time_varying(self): + return estimate_memory(self.time_varying) + + @property + def memory(self): + return self.memory_time_invariants + self.memory_time_varying + + +class Rewriter(object): + + """ + Transform expressions to reduce their operation count. + """ + + triggers = { + '_cse': ('basic', 'advanced'), + '_factorize': ('factorize', 'advanced'), + '_optimize_trigonometry': ('approx-trigonometry', 'advanced'), + '_replace_time_invariants': ('glicm', 'advanced') + } + + # Aggressive transformation if the operation count is greather than this + # empirically determined threshold + threshold = 15 + + def __init__(self, exprs): + self.exprs = exprs + + self.ops = OrderedDict([('baseline', estimate_cost(exprs))]) + self.timings = OrderedDict() + + def run(self, mode): + state = State(self.exprs) + + self._cse(state, mode=mode) + self._factorize(state, mode=mode) + self._optimize_trigonometry(state, mode=mode) + self._replace_time_invariants(state, mode=mode) + self._factorize(state, mode=mode) + + self._finalize(state) + + self._summary(mode) + + return state + + @dse_transformation + def _factorize(self, state, **kwargs): + """ + Collect terms in each expr in exprs based on the following heuristic: + + * Collect all literals; + * Collect all temporaries produced by CSE; + * If the expression has an operation count higher than + self.threshold, then this is applied recursively until + no more factorization opportunities are available. + """ + + processed = [] + for expr in state.exprs: + cost_expr = estimate_cost(expr) + + handle = collect_nested(expr) + cost_handle = estimate_cost(handle) + + if cost_handle < cost_expr and cost_handle >= Rewriter.threshold: + handle_prev = handle + cost_prev = cost_expr + while cost_handle < cost_prev: + handle_prev, handle = handle, collect_nested(handle) + cost_prev, cost_handle = cost_handle, estimate_cost(handle) + cost_handle, handle = cost_prev, handle_prev + + processed.append(handle) + + return {'exprs': processed} + + @dse_transformation + def _cse(self, state, **kwargs): + """ + Perform common subexpression elimination. + """ + + temporaries, leaves = cse(state.exprs, numbered_symbols(_temp_prefix)) + for i in range(len(state.exprs)): + leaves[i] = Eq(state.exprs[i].lhs, leaves[i].rhs) + + # Restore some of the common sub-expressions that have potentially + # been collected: simple index calculations (eg, t - 1), IndexedBase, + # Indexed, binary Add, binary Mul. + revert = OrderedDict() + keep = OrderedDict() + for k, v in temporaries: + if isinstance(v, (IndexedBase, Indexed)): + revert[k] = v + elif v.is_Add and not set([t, x, y, z]).isdisjoint(set(v.args)): + revert[k] = v + elif is_binary_op(v): + revert[k] = v + else: + keep[k] = v + for k, v in revert.items(): + mapper = {} + for i in preorder_traversal(v): + if isinstance(i, Indexed): + new_indices = [] + for index in i.indices: + if index in revert: + new_indices.append(revert[index]) + else: + new_indices.append(index) + if i.base.label in revert: + mapper[i] = Indexed(revert[i.base.label], *new_indices) + if i in revert: + mapper[i] = revert[i] + revert[k] = v.xreplace(mapper) + mapper = {} + for e in leaves + list(keep.values()): + for i in preorder_traversal(e): + if isinstance(i, Indexed): + new_indices = [] + for index in i.indices: + if index in revert: + new_indices.append(revert[index]) + else: + new_indices.append(index) + if i.base.label in revert: + mapper[i] = Indexed(revert[i.base.label], *new_indices) + elif tuple(new_indices) != i.indices: + mapper[i] = Indexed(i.base, *new_indices) + if i in revert: + mapper[i] = revert[i] + leaves = rxreplace(leaves, mapper) + kept = rxreplace([Eq(k, v) for k, v in keep.items()], mapper) + + # If the RHS of a temporary variable is the LHS of a leaf, + # update the value of the temporary variable after the leaf + new_leaves = [] + for leaf in leaves: + new_leaves.append(leaf) + for i in kept: + if leaf.lhs in preorder_traversal(i.rhs): + new_leaves.append(i) + break + + # Reshuffle to make sure temporaries come later than their read values + processed = OrderedDict([(i.lhs, i) for i in kept + new_leaves]) + temporaries = set(processed.keys()) + ordered = OrderedDict() + while processed: + k, v = processed.popitem(last=False) + temporary_reads = terminals(v.rhs) & temporaries - {v.lhs} + if all(i in ordered for i in temporary_reads): + ordered[k] = v + else: + # Must wait for some earlier temporaries, push back into queue + processed[k] = v + + return {'exprs': list(ordered.values())} + + @dse_transformation + def _optimize_trigonometry(self, state, **kwargs): + """ + Rebuild ``exprs`` replacing trigonometric functions with Bhaskara + polynomials. + """ + + processed = [] + for expr in state.exprs: + handle = expr.replace(sin, bhaskara_sin) + handle = handle.replace(cos, bhaskara_cos) + processed.append(handle) + + return {'exprs': processed} + + @dse_transformation + def _replace_time_invariants(self, state, **kwargs): + """ + Create a new expr' given expr where the longest time-invariant + sub-expressions are replaced by temporaries. A mapper from the + introduced temporaries to the corresponding time-invariant + sub-expressions is also returned. + + Examples + ======== + + (a+b)*c[t] + s*d[t] + v*(d + e[t] + r) + --> (t1*c[t] + s*d[t] + v*(e[t] + t2), {t1: (a+b), t2: (d+r)}) + (a*b[t] + c*d[t])*v[t] + --> ((a*b[t] + c*d[t])*v[t], {}) + """ + + template = "ti%d" + graph = temporaries_graph(state.exprs) + space_dimensions = graph.space_dimensions() + queue = graph.copy() + + # What expressions is it worth transforming (cm=cost model)? + # Formula: ops(expr)*aliases(expr) > self.threshold <==> do it + # For more information about "aliases", check out collect_aliases.__doc__ + aliases, clusters = collect_aliases([e.rhs for e in state.exprs]) + cm = lambda e: estimate_cost(e, True)*len(aliases.get(e, [e])) > self.threshold + + # Replace time invariants + processed = [] + mapper = OrderedDict() + while queue: + k, v = queue.popitem(last=False) + + make = lambda m: Indexed(template % (len(m)+len(mapper)), *space_dimensions) + invariant = lambda e: is_time_invariant(e, graph) + handle, flag, mapped = replace_invariants(v, make, invariant, cm) + + if flag: + mapper.update(mapped) + for i in v.readby: + graph[i] = graph[i].construct({k: handle.rhs}) + else: + processed.append(Eq(v.lhs, graph[v.lhs].rhs)) + + # Squash aliases and tweak the affected indices accordingly + reducible = OrderedDict() + others = OrderedDict() + for k, v in mapper.items(): + cluster = aliases.get(v) + if cluster: + index = clusters.index(cluster) + reducible.setdefault(index, []).append(k) + else: + others[k] = v + rule = {} + reduced_mapper = OrderedDict() + for i, cluster in enumerate(reducible.values()): + for k in cluster: + v, flipped = flip_indices(mapper[k], space_dimensions) + assert len(flipped) == 1 + reduced_mapper[Indexed(template % i, *space_dimensions)] = v + rule[k] = Indexed(template % i, *flipped.pop()) + handle, processed = list(processed), [] + for e in handle: + processed.append(e.xreplace(rule)) + for k, v in others.items(): + reduced_mapper[k] = v.xreplace(rule) + + return {'exprs': processed, 'mapper': reduced_mapper} + + def _finalize(self, state): + """ + Make sure that any subsequent sympy operation applied to the expressions + in ``state.exprs`` does not alter the structure of the transformed objects. + """ + exprs = [Eq(k, v) for k, v in state.mapper.items()] + state.exprs + state.update(exprs=[unevaluate_arithmetic(e) for e in exprs]) + + def _summary(self, mode): + """ + Print a summary of the DSE transformations + """ + + if mode.intersection({'basic', 'advanced'}): + try: + # The state after CSE should be used as baseline for fairness + baseline = self.ops['_cse0'] + except KeyError: + baseline = self.ops['baseline'] + self.ops.pop('baseline') + steps = " --> ".join("(%s) %d" % (filter(lambda c: not c.isdigit(), k), v) + for k, v in self.ops.items()) + try: + gain = float(baseline) / list(self.ops.values())[-1] + summary = " %s flops; gain: %.2f X" % (steps, gain) + except ZeroDivisionError: + summary = "" + elapsed = sum(self.timings.values()) + dse("Rewriter:%s [%.2f s]" % (summary, elapsed)) + + +def collect_nested(expr): + """ + Collect terms appearing in expr, checking all levels of the expression tree. + + :param expr: the expression to be factorized. + """ + + def run(expr): + # Return semantic (rebuilt expression, factorization candidates) + + if expr.is_Float: + return expr.func(*expr.atoms()), [expr] + elif isinstance(expr, Indexed): + return expr.func(*expr.args), [] + elif expr.is_Symbol: + return expr.func(expr.name), [expr] + elif expr in [S.Zero, S.One, S.NegativeOne, S.Half]: + return expr.func(), [expr] + elif expr.is_Atom: + return expr.func(*expr.atoms()), [] + elif expr.is_Add: + rebuilt, candidates = zip(*[run(arg) for arg in expr.args]) + + w_numbers = [i for i in rebuilt if any(j.is_Number for j in i.args)] + wo_numbers = [i for i in rebuilt if i not in w_numbers] + + w_numbers = collect_const(expr.func(*w_numbers)) + wo_numbers = expr.func(*wo_numbers) + + if wo_numbers: + for i in flatten(candidates): + wo_numbers = collect(wo_numbers, i) + + rebuilt = expr.func(w_numbers, wo_numbers) + return rebuilt, [] + elif expr.is_Mul: + rebuilt, candidates = zip(*[run(arg) for arg in expr.args]) + rebuilt = collect_const(expr.func(*rebuilt)) + return rebuilt, flatten(candidates) + else: + rebuilt, candidates = zip(*[run(arg) for arg in expr.args]) + return expr.func(*rebuilt), flatten(candidates) + + return run(expr)[0] + + +def replace_invariants(expr, make, invariant=lambda e: e, cm=lambda e: True): + """ + Replace all sub-expressions of ``expr`` such that ``invariant(expr) == True`` + with a temporary created through ``make(expr)``. A sub-expression ``e`` + within ``expr`` is not visited if ``cm(e) == False``. + """ + + def run(expr, root, mapper): + # Return semantic: (rebuilt expr, True <==> invariant) + + if expr.is_Float: + return expr.func(*expr.atoms()), True + elif expr in [S.Zero, S.One, S.NegativeOne, S.Half]: + return expr.func(), True + elif expr.is_Symbol: + return expr.func(expr.name), invariant(expr) + elif expr.is_Atom: + return expr.func(*expr.atoms()), True + elif isinstance(expr, Indexed): + return expr.func(*expr.args), invariant(expr) + elif expr.is_Equality: + handle, flag = run(expr.rhs, expr.rhs, mapper) + return expr.func(expr.lhs, handle, evaluate=False), flag + else: + children = [run(a, root, mapper) for a in expr.args] + invs = [a for a, flag in children if flag] + varying = [a for a, _ in children if a not in invs] + if not invs: + # Nothing is time-invariant + return (expr.func(*varying, evaluate=False), False) + elif len(invs) == len(children): + # Everything is time-invariant + if expr == root: + if cm(expr): + temporary = make(mapper) + mapper[temporary] = expr.func(*invs, evaluate=False) + return temporary, True + else: + return expr.func(*invs, evaluate=False), False + else: + # Go look for longer expressions first + return expr.func(*invs, evaluate=False), True + else: + # Some children are time-invariant, but expr is time-dependent + if cm(expr) and len(invs) > 1: + temporary = make(mapper) + mapper[temporary] = expr.func(*invs, evaluate=False) + return expr.func(*(varying + [temporary]), evaluate=False), False + else: + return expr.func(*(varying + invs), evaluate=False), False + + mapper = OrderedDict() + handle, flag = run(expr, expr, mapper) + return handle, flag, mapper diff --git a/devito/expression.py b/devito/expression.py index 005c9769b4..8216b573da 100644 --- a/devito/expression.py +++ b/devito/expression.py @@ -8,8 +8,8 @@ from devito.codeprinter import ccode from devito.dimension import Dimension +from devito.dse.inspection import indexify, terminals from devito.interfaces import SymbolicData -from devito.symbolics import dse_indexify, terminals from devito.tools import filter_ordered __all__ = ['Expression'] @@ -62,7 +62,7 @@ def signature(self): def indexify(self): """Convert stencil expression to "indexed" format""" - self.stencil = dse_indexify(self.stencil) + self.stencil = indexify(self.stencil) @property def index_offsets(self): diff --git a/devito/function_manager.py b/devito/function_manager.py index 6187b85fb0..629ba0ffa1 100644 --- a/devito/function_manager.py +++ b/devito/function_manager.py @@ -13,7 +13,8 @@ class FunctionManager(object): :param openmp: True if using OpenMP. Default is False """ libraries = ['assert.h', 'stdlib.h', 'math.h', - 'stdio.h', 'string.h', 'sys/time.h'] + 'stdio.h', 'string.h', 'sys/time.h', + 'xmmintrin.h', 'pmmintrin.h'] _pymic_attribute = 'PYMIC_KERNEL' @@ -136,7 +137,7 @@ def generate_function_body(self, function_descriptor): '*%s' % (param[1]+"_pointer")) statements.append(cast_pointer) - statements.append(function_descriptor.body) + statements.extend(function_descriptor.body) statements.append(cgen.Statement("return 0")) return cgen.Block(statements) diff --git a/devito/interfaces.py b/devito/interfaces.py index f176879274..45f864e392 100644 --- a/devito/interfaces.py +++ b/devito/interfaces.py @@ -73,6 +73,11 @@ def __init__(self, \*args, \*\*kwargs): to (re-)create the dimension arguments of the symbolic function. """ + is_DenseData = False + is_TimeData = False + is_Coordinates = False + is_PointData = False + def __new__(cls, *args, **kwargs): if cls in _SymbolCache: newobj = Function.__new__(cls, *args) @@ -115,6 +120,9 @@ class DenseData(SymbolicData): therefore do not support time derivatives. Use :class:`TimeData` for time-varying grid data. """ + + is_DenseData = True + def __init__(self, *args, **kwargs): if not self._cached(): self.name = kwargs.get('name') @@ -371,6 +379,8 @@ class TimeData(DenseData): and whether we want to write intermediate timesteps in the buffer. """ + is_TimeData = True + def __init__(self, *args, **kwargs): if not self._cached(): super(TimeData, self).__init__(*args, **kwargs) @@ -462,9 +472,11 @@ def dt2(self): class CoordinateData(SymbolicData): - """Data object for sparse coordinate data that acts as a Function symbol - """ + Data object for sparse coordinate data that acts as a Function symbol + """ + + is_Coordinates = True def __init__(self, *args, **kwargs): if not self._cached(): @@ -501,7 +513,8 @@ def indexed(self): class PointData(DenseData): - """Data object for sparse point data that acts as a Function symbol + """ + Data object for sparse point data that acts as a Function symbol :param name: Name of the resulting :class:`sympy.Function` symbol :param npoint: Number of points to sample @@ -514,6 +527,8 @@ class PointData(DenseData): symbolic behaviour follows the use in the current problem. """ + is_PointData = True + def __init__(self, *args, **kwargs): if not self._cached(): self.nt = kwargs.get('nt') diff --git a/devito/logger.py b/devito/logger.py index 2197ba2733..a70d401886 100644 --- a/devito/logger.py +++ b/devito/logger.py @@ -4,9 +4,9 @@ import sys __all__ = ('set_log_level', 'set_log_noperf', 'log', - 'DEBUG', 'INFO', 'AUTOTUNER', 'PERF_OK', 'PERF_WARN', - 'WARNING', 'ERROR', 'CRITICAL', - 'log', 'warning', 'error', 'info_at', 'perfok', 'perfbad', + 'DEBUG', 'INFO', 'AUTOTUNER', 'DSE', 'DSE_WARN', 'WARNING', + 'ERROR', 'CRITICAL', + 'log', 'warning', 'error', 'info_at', 'dse', 'dse_warning', 'RED', 'GREEN', 'BLUE') @@ -18,15 +18,15 @@ DEBUG = logging.DEBUG INFO = logging.INFO AUTOTUNER = 27 -PERF_OK = 28 -PERF_WARN = 29 +DSE = 28 +DSE_WARN = 29 WARNING = logging.WARNING ERROR = logging.ERROR CRITICAL = logging.CRITICAL logging.addLevelName(AUTOTUNER, "AUTOTUNER") -logging.addLevelName(PERF_OK, "PERF_OK") -logging.addLevelName(PERF_WARN, "PERF_WARN") +logging.addLevelName(DSE, "DSE") +logging.addLevelName(DSE_WARN, "DSE_WARN") logger.setLevel(INFO) @@ -39,8 +39,8 @@ DEBUG: RED, INFO: NOCOLOR, AUTOTUNER: GREEN, - PERF_OK: GREEN, - PERF_WARN: BLUE, + DSE: NOCOLOR, + DSE_WARN: BLUE, WARNING: BLUE, ERROR: RED, CRITICAL: RED @@ -51,7 +51,7 @@ def set_log_level(level): """ Set the log level of the Devito logger. - :param level: accepted values are: DEBUG, INFO, AUTOTUNER, PERF_OK, PERF_WARN, + :param level: accepted values are: DEBUG, INFO, AUTOTUNER, DSE, DSE_WARN, WARNING, ERROR, CRITICAL """ logger.setLevel(level) @@ -68,10 +68,10 @@ def log(msg, level=INFO, *args, **kwargs): the severity 'level'. :param msg: the message to be printed. - :param level: accepted values are: DEBUG, INFO, AUTOTUNER, PERF_OK, PERF_WARN, + :param level: accepted values are: DEBUG, INFO, AUTOTUNER, DSE, DSE_WARN, WARNING, ERROR, CRITICAL """ - assert level in [DEBUG, INFO, AUTOTUNER, PERF_OK, PERF_WARN, + assert level in [DEBUG, INFO, AUTOTUNER, DSE, DSE_WARN, WARNING, ERROR, CRITICAL] color = COLORS[level] if sys.stdout.isatty() and sys.stderr.isatty() else '%s' @@ -98,9 +98,9 @@ def debug(msg, *args, **kwargs): log(msg, DEBUG, *args, **kwargs) -def perfok(msg, *args, **kwargs): - log(msg, PERF_OK, *args, **kwargs) +def dse(msg, *args, **kwargs): + log("DSE: %s" % msg, DSE, *args, **kwargs) -def perfbad(msg, *args, **kwargs): - log(msg, PERF_WARN, *args, **kwargs) +def dse_warning(msg, *args, **kwargs): + log(msg, DSE_WARN, *args, **kwargs) diff --git a/devito/operator.py b/devito/operator.py index e5badf8bcc..9404fff33c 100644 --- a/devito/operator.py +++ b/devito/operator.py @@ -5,10 +5,10 @@ from devito.compiler import get_compiler_from_env from devito.dimension import t, x, y, z +from devito.dse.inspection import (indexify, retrieve_dimensions, + retrieve_symbols, tolambda) from devito.interfaces import TimeData from devito.propagator import Propagator -from devito.symbolics import (dse_dimensions, dse_indexify, dse_rewrite, - dse_symbols, dse_tolambda) __all__ = ['Operator'] @@ -44,14 +44,12 @@ class Operator(object): tuned block sizes :param input_params: List of symbols that are expected as input. :param output_params: List of symbols that define operator output. - :param factorized: A map given by {string_name:sympy_object} for including factorized - terms """ def __init__(self, nt, shape, dtype=np.float32, stencils=[], subs=[], spc_border=0, time_order=0, forward=True, compiler=None, profile=False, dse='advanced', cache_blocking=None, input_params=None, - output_params=None, factorized={}): + output_params=None): # Derive JIT compilation infrastructure self.compiler = compiler or get_compiler_from_env() @@ -65,13 +63,13 @@ def __init__(self, nt, shape, dtype=np.float32, stencils=[], sym_undef = set() for eqn in self.stencils: - lhs_def, lhs_undef = dse_symbols(eqn.lhs) + lhs_def, lhs_undef = retrieve_symbols(eqn.lhs) sym_undef.update(lhs_undef) if self.output_params is None: self.output_params = list(lhs_def) - rhs_def, rhs_undef = dse_symbols(eqn.rhs) + rhs_def, rhs_undef = retrieve_symbols(eqn.rhs) sym_undef.update(rhs_undef) if self.input_params is None: @@ -80,8 +78,8 @@ def __init__(self, nt, shape, dtype=np.float32, stencils=[], # Pull all dimension indices from the incoming stencil dimensions = [] for eqn in self.stencils: - dimensions += [i for i in dse_dimensions(eqn.lhs) if i not in dimensions] - dimensions += [i for i in dse_dimensions(eqn.rhs) if i not in dimensions] + dimensions += [i for i in retrieve_dimensions(eqn.lhs) if i not in dimensions] + dimensions += [i for i in retrieve_dimensions(eqn.rhs) if i not in dimensions] # Time dimension is fixed for now time_dim = t @@ -110,23 +108,17 @@ def __init__(self, nt, shape, dtype=np.float32, stencils=[], for eqn in self.stencils] # Convert incoming stencil equations to "indexed access" format - self.stencils = [Eq(dse_indexify(eqn.lhs), dse_indexify(eqn.rhs)) + self.stencils = [Eq(indexify(eqn.lhs), indexify(eqn.rhs)) for eqn in self.stencils] - for name, value in factorized.items(): - factorized[name] = dse_indexify(value) - - # Applies CSE - self.stencils = dse_rewrite(self.stencils, mode=dse) - # Apply user-defined subs to stencil self.stencils = [eqn.subs(subs[0]) for eqn in self.stencils] - self.propagator = Propagator(self.getName(), nt, shape, self.stencils, - factorized=factorized, dtype=dtype, - spc_border=spc_border, time_order=time_order, - forward=forward, space_dims=self.space_dims, - compiler=self.compiler, profile=profile, - cache_blocking=cache_blocking) + + self.propagator = Propagator(self.getName(), nt, shape, self.stencils, dse=dse, + dtype=dtype, spc_border=spc_border, + time_order=time_order, forward=forward, + space_dims=self.space_dims, compiler=self.compiler, + profile=profile, cache_blocking=cache_blocking) self.dtype = dtype self.nt = nt self.shape = shape @@ -137,15 +129,6 @@ def __init__(self, nt, shape, dtype=np.float32, stencils=[], for param in self.signature: self.propagator.add_devito_param(param) self.symbol_to_data[param.name] = param - self.propagator.stencils = self.stencils - self.propagator.factorized = factorized - for name, val in factorized.items(): - if forward: - self.propagator.factorized[name] = \ - dse_indexify(val.subs(t, t - 1)).subs(subs[1]) - else: - self.propagator.factorized[name] = \ - dse_indexify(val.subs(t, t + 1)).subs(subs[1]) @property def signature(self): @@ -200,9 +183,9 @@ def run_python(self): Execute the operator using Python """ time_loop_limits = self.propagator.time_loop_limits - time_loop_lambdas_b = dse_tolambda(self.propagator.time_loop_stencils_b) - time_loop_lambdas_a = dse_tolambda(self.propagator.time_loop_stencils_a) - stencil_lambdas = dse_tolambda(self.stencils) + time_loop_lambdas_b = tolambda(self.propagator.time_loop_stencils_b) + time_loop_lambdas_a = tolambda(self.propagator.time_loop_stencils_a) + stencil_lambdas = tolambda(self.stencils) for ti in range(*time_loop_limits): # Run time loop stencils before space loop diff --git a/devito/profiler.py b/devito/profiler.py index f9fe6eb331..0ef5895dd2 100644 --- a/devito/profiler.py +++ b/devito/profiler.py @@ -2,7 +2,7 @@ import numpy as np -from devito.cgen_wrapper import Assign, Block, For, Statement, Struct, Value +from devito.cgen_wrapper import Block, Statement, Struct, Value class Profiler(object): @@ -22,8 +22,6 @@ def __init__(self, openmp=False, dtype=np.float32): self._C_timings = None - self.num_flops = {} - def add_profiling(self, code, name, omp_flag=None, to_ignore=None): """Function to add profiling code to the given :class:`cgen.Block`. @@ -48,9 +46,6 @@ def add_profiling(self, code, name, omp_flag=None, to_ignore=None): self.t_fields.append((name, c_double)) - self.num_flops[name] = FlopsCounter(code, name, self.openmp, - self.float_size, to_ignore or []).run() - init = [ Statement("struct timeval start_%s, end_%s" % (name, name)) ] + omp_flag + [Statement("gettimeofday(&start_%s, NULL)" % name)] @@ -124,136 +119,3 @@ def timings(self): for field, _ in self._C_timings._fields_} else: return {} - - @property - def gflops(self): - """GFlops per loop iteration, keyed by code section.""" - return self.num_flops - - -class FlopsCounter(object): - - """Compute the operational intensity of a stencil.""" - - def __init__(self, code, name, openmp, float_size, to_ignore): - self.code = code - self.name = name - self.openmp = openmp - self.float_size = float_size - - self.to_ignore = [ - "int", - "float", - "double", - "F", - "e", - "fabsf", - "powf", - "floor", - "ceil", - "temp", - "i", - "t", - "p", # This one shouldn't be here. - # It should be passed in by an Iteration object. - # Added only because tti_example uses it. - ] + to_ignore - self.seen = set() - - def run(self): - """ - Calculates the total operation intensity of the code provided. - If needed, lets the C code calculate it. - """ - num_flops = 0 - - for elem in self.code: - if isinstance(elem, Assign): - num_flops += self._handle_assign(elem) - elif isinstance(elem, For): - num_flops += self._handle_for(elem) - elif isinstance(elem, Block): - num_flops += self._handle_block(elem)[0] - else: - # no op - pass - - return num_flops - - def _handle_for(self, loop): - loop_flops = 0 - loop_oi_f = 0 - - if isinstance(loop.body, Assign): - loop_flops = self._handle_assign(loop.body) - loop_oi_f = loop_flops - elif isinstance(loop.body, Block): - loop_oi_f, loop_flops = self._handle_block(loop.body) - elif isinstance(loop.body, For): - loop_oi_f = self._handle_for(loop.body) - else: - # no op - pass - - old_body = loop.body - - while isinstance(old_body, Block) and isinstance(old_body.contents[0], Statement): - old_body = old_body.contents[1] - - if loop_flops == 0: - if old_body in self.seen: - return 0 - - self.seen.add(old_body) - - return loop_oi_f - - if old_body in self.seen: - return 0 - - self.seen.add(old_body) - return loop_oi_f - - def _handle_block(self, block): - block_flops = 0 - block_oi = 0 - - for elem in block.contents: - if isinstance(elem, Assign): - a_flops = self._handle_assign(elem) - block_flops += a_flops - block_oi += a_flops - elif isinstance(elem, Block): - nblock_oi, nblock_flops = self._handle_block(elem) - block_oi += nblock_oi - block_flops += nblock_flops - elif isinstance(elem, For): - block_oi += self._handle_for(elem) - else: - # no op - pass - - return block_oi, block_flops - - def _handle_assign(self, assign): - flops = 0 - - # removing casting statements and function calls to floor - # that can confuse the parser - string = assign.lvalue + " " + assign.rvalue - - brackets = 0 - for idx in range(len(string)): - c = string[idx] - - # We skip index operations. The third check works because in the - # generated code constants always precede variables in operations - # and is needed because Sympy prints fractions like this: 1.0F/4.0F - if brackets == 0 and c in "*/-+" and not string[idx+1].isdigit(): - flops += 1 - elif c == "[": - brackets += 1 - elif c == "]": - brackets -= 1 - - return flops diff --git a/devito/propagator.py b/devito/propagator.py index 953e46278d..8b94efe467 100644 --- a/devito/propagator.py +++ b/devito/propagator.py @@ -1,7 +1,7 @@ from __future__ import absolute_import import operator -from collections import Iterable, defaultdict +from collections import Iterable, OrderedDict, defaultdict from functools import reduce from hashlib import sha1 from os import path @@ -16,12 +16,13 @@ from devito.compiler import (IntelMICCompiler, get_compiler_from_env, get_tmp_dir, jit_compile_and_load) from devito.dimension import t, x, y, z +from devito.dse.inspection import retrieve_dtype +from devito.dse.symbolics import _temp_prefix, rewrite from devito.expression import Expression from devito.function_manager import FunctionDescriptor, FunctionManager from devito.iteration import Iteration from devito.logger import info from devito.profiler import Profiler -from devito.symbolics import dse_dtype from devito.tools import flatten @@ -41,8 +42,6 @@ class Propagator(object): :param nt: Number of timesteps to execute :param shape: Shape of the data buffer over which to execute :param stencils: List of :class:`sympy.Eq` used to create the kernel - :param factorized: A map given by {string_name:sympy_object} for including - factorized terms :param spc_border: Number of spatial padding layers :param time_order: Order of the time discretisation :param time_dim: Symbol that defines the time dimension @@ -52,6 +51,8 @@ class Propagator(object): If not provided, the compiler will be inferred from the environment variable DEVITO_ARCH, or default to GNUCompiler :param profile: Flag to enable performance profiling + :param dse: Set of transformations applied by the Devito Symbolic Engine. + Available: [None, 'basic', 'advanced' (default)] :param cache_blocking: Block sizes used for cache clocking. Can be either a single number used for all dimensions except inner most or a list explicitly stating block sizes for each dimension @@ -60,21 +61,19 @@ class Propagator(object): tune block sizes """ - def __init__(self, name, nt, shape, stencils, factorized=None, spc_border=0, - time_order=0, time_dim=None, space_dims=None, dtype=np.float32, - forward=True, compiler=None, profile=False, cache_blocking=None): + def __init__(self, name, nt, shape, stencils, dse=None, spc_border=0, time_order=0, + time_dim=None, space_dims=None, dtype=np.float32, forward=True, + compiler=None, profile=False, cache_blocking=None): self.stencils = stencils self.dtype = dtype - self.factorized = factorized or {} self.time_order = time_order self.spc_border = spc_border self.loop_body = None # Default time and space symbols if not provided self.time_dim = time_dim or t - if space_dims is not None: - self.space_dims = space_dims - else: - self.space_dims = (x, z) if len(shape) == 2 else (x, y, z)[:len(shape)] + if not space_dims: + space_dims = (x, z) if len(shape) == 2 else (x, y, z)[:len(shape)] + self.space_dims = tuple(space_dims) self.shape = shape # Internal flags and meta-data @@ -83,6 +82,7 @@ def __init__(self, name, nt, shape, stencils, factorized=None, spc_border=0, self.time_steppers = [] self.time_order = time_order self.nt = nt + self.time_invariants = [] self.time_loop_stencils_b = [] self.time_loop_stencils_a = [] @@ -115,6 +115,11 @@ def __init__(self, name, nt, shape, stencils, factorized=None, spc_border=0, # Profiler needs to know whether openmp is set self.profiler = Profiler(self.compiler.openmp, self.dtype) + # Performance data + self._dse = dse + self._ops = {} + self._memory = {} + # Cache blocking and block sizes self.cache_blocking = cache_blocking self.block_sizes = [] @@ -185,8 +190,12 @@ def ccode(self): if self._ccode is None: manager = FunctionManager([self.fd], mic_flag=self.mic_flag, openmp=self.compiler.openmp) - # For some reason we need this call to trigger fd.body - self.get_fd() + + # Go through the Devito Symbolic Engine to generate optimized code + self._optimize() + + # Perform code generation + self._get_fd() if self.profile: manager.add_struct_definition(self.profiler.as_cgen_struct(Profiler.TIME)) @@ -223,20 +232,18 @@ def timings(self): def oi(self): """Summary of operational intensities, by code section.""" - gflops_per_section = self.gflops - bytes_per_section = self.traffic() oi_per_section = {} for i, subsection in enumerate(self.time_loop_stencils_b): key = "%s%d" % (PRE_STENCILS.name, i) - oi_per_section[key] = 1.0*gflops_per_section[key]/bytes_per_section[key] + oi_per_section[key] = 1.0*self.gflops[key]/self.traffic[key] key = LOOP_BODY.name - oi_per_section[key] = 1.0*gflops_per_section[key]/bytes_per_section[key] + oi_per_section[key] = 1.0*self.gflops[key]/self.traffic[key] for i, subsection in enumerate(self.time_loop_stencils_a): key = "%s%d" % (POST_STENCILS.name, i) - oi_per_section[key] = 1.0*gflops_per_section[key]/bytes_per_section[key] + oi_per_section[key] = 1.0*self.gflops[key]/self.traffic[key] return oi_per_section @@ -253,9 +260,10 @@ def niters(self): niters = subsection.limits[1] if isinstance(subsection, Iteration) else 1 niters_per_section[key] = with_time_loop(niters) - key = LOOP_BODY.name - niters = reduce(operator.mul, self.shape) - niters_per_section[key] = with_time_loop(niters) + niters = reduce(operator.mul, + [j - i for i, j in self._space_loop_limits.values()]) + niters_per_section[TIME_INVARIANTS.name] = niters + niters_per_section[LOOP_BODY.name] = with_time_loop(niters) for i, subsection in enumerate(self.time_loop_stencils_a): key = "%s%d" % (POST_STENCILS.name, i) @@ -264,89 +272,37 @@ def niters(self): return niters_per_section - def traffic(self, mode='realistic'): - """Summary of Bytes moved between CPU (last level cache) and DRAM, - by code section. - - :param mode: Several estimates are possible: :: - - * ideal: also known as "compulsory traffic", which is the minimum - number of bytes to be moved (ie, models an infinite cache) - * ideal_with_stores: like ideal, but a data item which is both read - and written is counted twice (load + store) - * realistic: assume that all datasets, even those that do not depend - on time, need to be re-loaded at each time iteration - """ - - assert mode in ['ideal', 'ideal_with_stores', 'realistic'] - - def access(symbol): - assert isinstance(symbol, Indexed) - # Irregular accesses (eg A[B[i]]) are counted as compulsory traffic - if any(i.atoms(Indexed) for i in symbol.indices): - return symbol - else: - return symbol.base - - def count(self, expressions): - if mode in ['ideal', 'ideal_with_stores']: - filter = lambda s: self.time_dim in s.atoms() - else: - filter = lambda s: s - reads = set(flatten([e.rhs.atoms(Indexed) for e in expressions])) - writes = set(flatten([e.lhs.atoms(Indexed) for e in expressions])) - reads = set([access(s) for s in reads if filter(s)]) - writes = set([access(s) for s in writes if filter(s)]) - if mode == 'ideal': - return len(set(reads) | set(writes)) - else: - return len(reads) + len(writes) + @property + def dataspace(self): + """Summary of data items accessed, by code section.""" + handle = self.niters + handle[LOOP_BODY.name] = reduce(operator.mul, (self.nt,) + self.shape) + return handle - niters = self.niters + @property + def traffic(self): + """Summary of Bytes moved between CPU (last level cache) and DRAM, + by code section.""" dsize = np.dtype(self.dtype).itemsize - - bytes_per_section = {} - - for i, subsection in enumerate(self.time_loop_stencils_b): - key = "%s%d" % (PRE_STENCILS.name, i) - if isinstance(subsection, Iteration): - expressions = [e.stencil for e in subsection.expressions] - else: - expressions = subsection.stencil - bytes_per_section[key] = dsize*count(self, expressions)*niters[key] - - key = LOOP_BODY.name - bytes_per_section[key] = dsize*count(self, self.stencils)*niters[key] - - for i, subsection in enumerate(self.time_loop_stencils_a): - key = "%s%d" % (POST_STENCILS.name, i) - if isinstance(subsection, Iteration): - expressions = [e.stencil for e in subsection.expressions] - else: - expressions = subsection.stencil - bytes_per_section[key] = dsize*count(self, expressions)*niters[key] - - return bytes_per_section + return {k: dsize*self.dataspace[k]*v for k, v in self._memory.items()} @property def gflops(self): """Summary of GFlops performed, by code section.""" niters = self.niters - - gflops_per_iteration = self.profiler.gflops gflops_per_section = {} for i, subsection in enumerate(self.time_loop_stencils_b): key = "%s%d" % (PRE_STENCILS.name, i) - gflops_per_section[key] = gflops_per_iteration[key]*niters[key] + gflops_per_section[key] = self._ops[key]*niters[key] key = LOOP_BODY.name - gflops_per_section[key] = gflops_per_iteration[key]*niters[key] + gflops_per_section[key] = self._ops[key]*niters[key] for i, subsection in enumerate(self.time_loop_stencils_a): key = "%s%d" % (POST_STENCILS.name, i) - gflops_per_section[key] = gflops_per_iteration[key]*niters[key] + gflops_per_section[key] = self._ops[key]*niters[key] return gflops_per_section @@ -416,38 +372,22 @@ def sympy_to_cgen(self, stencils): :param stencils: A list of stencils to be converted :returns: :class:`cgen.Block` containing the converted kernel """ - - factors = [] - if len(self.factorized) > 0: - for name, term in zip(self.factorized.keys(), self.factorized): - expr = self.factorized[name] - self.add_local_var(name, self.dtype) - sub = str(ccode(self.time_substitutions(expr).xreplace(self._mapper))) - if self.dtype is np.float32: - factors.append(cgen.Assign(name, (sub.replace("pow", "powf") - .replace("fabs", "fabsf")))) - else: - factors.append(cgen.Assign(name, sub)) - - decl = [] - + declarations = [] declared = defaultdict(bool) for eqn in stencils: s_lhs = str(eqn.lhs) - if s_lhs.find("temp") is not -1 and not declared[s_lhs]: - expr_dtype = dse_dtype(eqn.rhs) or self.dtype + if s_lhs.find(_temp_prefix) is not -1 and not declared[s_lhs]: + expr_dtype = retrieve_dtype(eqn.rhs) or self.dtype declared[s_lhs] = True - decl.append(cgen.Value(cgen.dtype_to_ctype(expr_dtype), - ccode(eqn.lhs))) + value = cgen.Value(cgen.dtype_to_ctype(expr_dtype), ccode(eqn.lhs)) + declarations.append(value) stmts = [self.convert_equality_to_cgen(x) for x in stencils] - for idx, dec in enumerate(decl): + for idx, dec in enumerate(declarations): stmts[idx] = cgen.Assign(dec.inline(), stmts[idx].rvalue) - kernel = stmts - - return cgen.Block(factors+kernel) + return cgen.Block(stmts) def convert_equality_to_cgen(self, equality): """Convert given equality to :class:`cgen.Generable` statement @@ -474,19 +414,17 @@ def convert_equality_to_cgen(self, equality): return cgen.Assign(s_lhs, s_rhs) - def get_aligned_pragma(self, stencils, factorized, time_steppers): + def get_aligned_pragma(self, stencils, time_steppers): """ Sets the alignment for the pragma. :param stencils: List of stencils. - :param factorized: dict of factorized elements :param time_steppers: list of time stepper symbols """ array_names = set() for item in flatten([stencil.free_symbols for stencil in stencils]): if ( - str(item) not in factorized - and item not in self._mapper.values() + time_steppers - and str(item).find("temp") == -1 + item not in self._mapper.values() + time_steppers + and str(item).find(_temp_prefix) == -1 ): array_names.add(item) if len(array_names) == 0: @@ -496,7 +434,37 @@ def get_aligned_pragma(self, stencils, factorized, time_steppers): ", ".join([str(i) for i in array_names]) )) - def generate_loops(self, loop_body): + def _optimize(self): + """ + Use the Devito Symbolic Engine to reduce the operation count of stencils + and any other expressions appearing in the kernel. + """ + + handle = rewrite(self.stencils, mode=self._dse) + + self.stencils = handle.time_varying + self._ops[LOOP_BODY.name] = handle.ops_time_varying + self._memory[LOOP_BODY.name] = handle.memory_time_varying + + self.time_invariants.extend(handle.time_invariants) + self._ops[TIME_INVARIANTS.name] = handle.ops_time_invariants + self._memory[TIME_INVARIANTS.name] = handle.memory_time_invariants + + sections = OrderedDict() + sections[PRE_STENCILS.name] = self.time_loop_stencils_b + sections[POST_STENCILS.name] = self.time_loop_stencils_a + for k, section in sections.items(): + for i, subsection in enumerate(section): + if isinstance(subsection, Iteration): + exprs = [e.stencil for e in subsection.expressions] + else: + exprs = subsection.stencil + handle = rewrite(exprs, mode='noop') + key = "%s%d" % (k, i) + self._ops[key] = handle.ops_time_varying + self._memory[key] = handle.memory_time_varying + + def generate_loops(self): """Assuming that the variable order defined in init (#var_order) is the order the corresponding dimensions are layout in memory, the last variable in that definition should be the fastest varying dimension in the arrays. @@ -506,6 +474,22 @@ def generate_loops(self, loop_body): :param loop_body: Statement representing the loop body :returns: :class:`cgen.Block` representing the loop """ + + if self.loop_body: + time_invariants = [] + loop_body = self.loop_body + elif self.time_order: + time_invariants = self.time_invariants + loop_body = self.sympy_to_cgen(self.stencils) + else: + time_invariants = [] + loop_body = self.sympy_to_cgen(self.stencils) + + # Init code before the time loop + header = [cgen.Value("int", i.name) for i in self.time_steppers] + # Clean-up code after the time loop + bottom = [] + # Space loops if not isinstance(loop_body, cgen.Block) or len(loop_body.contents) > 0: if self.cache_blocking is not None: @@ -532,6 +516,40 @@ def generate_loops(self, loop_body): time_stepping = [] if len(loop_body) > 0: loop_body = [cgen.Block(omp_for + loop_body)] + + # Generate code to be inserted outside of the space loops + if time_invariants: + ctype = cgen.dtype_to_ctype(self.dtype) + getname = lambda i: i.lhs.base if isinstance(i.lhs, Indexed) else i.lhs + values = { + "type": ctype, + "name": "%(name)s", + "dsize": "".join("[%d]" % j for j in self.shape[:-1]), + "size": "".join("[%d]" % j for j in self.shape) + } + declaration = "%(type)s (*%(name)s)%(dsize)s;" % values + header.extend([cgen.Line(declaration % {'name': getname(i)}) + for i in time_invariants]) + funcall = "posix_memalign((void**)&%(name)s, 64, sizeof(%(type)s%(size)s));" + funcall = funcall % values + funcalls = [cgen.Line(funcall % {'name': getname(i)}) + for i in time_invariants] + bottom = [cgen.Line('free(%s);' % getname(i)) for i in time_invariants] + time_invariants = [self.convert_equality_to_cgen(i) + for i in time_invariants] + time_invariants = self.generate_space_loops(cgen.Block(time_invariants), + full=True) + time_invariants = [cgen.Block(funcalls + omp_for + time_invariants)] + if self.profile: + time_invariants = self.profiler.add_profiling(time_invariants, + TIME_INVARIANTS.name) + header.extend(time_invariants) + + # Avoid denormal numbers + extra = [cgen.Line('_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);'), + cgen.Line('_MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);')] + header.extend(omp_parallel + [cgen.Block(extra)]) + # Statements to be inserted into the time loop before the spatial loop pre_stencils = [self.time_substitutions(x) for x in self.time_loop_stencils_b] @@ -574,15 +592,11 @@ def generate_loops(self, loop_body): t_var + "+=" + str(self._time_step), loop_body ) + loop_body = omp_parallel + [cgen.Block([loop_body])] - # Code to declare the time stepping variables (outside the time loop) - def_time_step = [cgen.Value("int", t_var_def.name) - for t_var_def in self.time_steppers] - body = def_time_step + omp_parallel + [loop_body] - - return cgen.Block(body) + return header + loop_body + bottom - def generate_space_loops(self, loop_body): + def generate_space_loops(self, loop_body, full=False): """Generate list for a non cache blocking space loop :param loop_body: Statement representing the loop body :returns: :list a list of for loops @@ -592,6 +606,9 @@ def generate_space_loops(self, loop_body): for spc_var in reversed(list(self.space_dims)): dim_var = self._mapper[spc_var] loop_limits = self._space_loop_limits[spc_var] + if full: + loop_limits = (loop_limits[0] - self.spc_border, + loop_limits[1] + self.spc_border) loop_body = cgen.For( cgen.InlineInitializer(cgen.Value("int", dim_var), str(loop_limits[0])), str(dim_var) + "<" + str(loop_limits[1]), @@ -733,10 +750,10 @@ def add_inner_most_dim_pragma(self, inner_most_dim, space_dims, loop_body): :return: cgen.Block - loop body with pragma """ if inner_most_dim and len(space_dims) > 1: - pragma = [self.get_aligned_pragma(self.sub_stencils, self.factorized, - self.time_steppers)]\ - if self.compiler.openmp else (self.compiler.pragma_ivdep + - self.compiler.pragma_nontemporal) + if self.compiler.openmp: + pragma = [self.get_aligned_pragma(self.sub_stencils, self.time_steppers)] + else: + pragma = self.compiler.pragma_ivdep + self.compiler.pragma_nontemporal loop_body = cgen.Block(pragma + [loop_body]) return loop_body @@ -816,7 +833,7 @@ def add_local_var(self, name, dtype): return symbols(name) - def get_fd(self): + def _get_fd(self): """Get a FunctionDescriptor that describes the code represented by this Propagator in the format that FunctionManager and JitManager can deal with it. Before calling, make sure you have either called set_jit_params @@ -824,13 +841,8 @@ def get_fd(self): :returns: The resulting :class:`devito.function_manager.FunctionDescriptor` """ - # Assume we have been given a a loop body in cgen types - if self.loop_body is not None: - self.fd.set_body(self.generate_loops(self.loop_body)) - else: # We might have been given Sympy expression to evaluate - # This is the more common use case so this will show up in error messages - self.fd.set_body(self.generate_loops(self.sympy_to_cgen(self.stencils))) + self.fd.set_body(self.generate_loops()) return self.fd def get_time_stepping(self): @@ -878,17 +890,9 @@ def time_substitutions(self, sympy_expr): for arg in postorder_traversal(sympy_expr): if isinstance(arg, Indexed): - array_term = arg - - if not str(array_term.base.label) in self.save_vars: - raise ValueError( - "Invalid variable '%s' in sympy expression." - " Did you add it to the operator's params?" - % str(array_term.base.label) - ) - - if not self.save_vars[str(array_term.base.label)]: - subs_dict[arg] = array_term.xreplace(self.t_replace) + is_saved = self.save_vars.get(str(arg.base.label), True) + if not is_saved: + subs_dict[arg] = arg.xreplace(self.t_replace) return sympy_expr.xreplace(subs_dict) @@ -901,6 +905,7 @@ def __init__(self, name): self.name = name +TIME_INVARIANTS = Section('time_invariants') PRE_STENCILS = Section('pre_stencils') LOOP_BODY = Section('loop_body') POST_STENCILS = Section('post_stencils') diff --git a/devito/symbolics.py b/devito/symbolics.py deleted file mode 100644 index 978eacacb7..0000000000 --- a/devito/symbolics.py +++ /dev/null @@ -1,481 +0,0 @@ -""" -The Devito symbolic engine is built on top of SymPy and provides two -classes of functions: -- for inspection of expressions -- for (in-place) manipulation of expressions -- for creation of new objects given some expressions -All exposed functions are prefixed with 'dse' (devito symbolic engine) -""" - -from __future__ import absolute_import - -from collections import OrderedDict - -import numpy as np -import sympy -from sympy import (Eq, Expr, Function, Indexed, IndexedBase, S, Symbol, - collect, collect_const, count_ops, cse, flatten, lambdify, - numbered_symbols, preorder_traversal, symbols) -from sympy.core.basic import _aresame - -from devito.dimension import t, x, y, z -from devito.interfaces import SymbolicData -from devito.logger import perfbad, perfok, warning - -__all__ = ['dse_dimensions', 'dse_symbols', 'dse_dtype', 'dse_indexify', - 'dse_tolambda', 'dse_rewrite'] - -_temp_prefix = 'temp' - - -# Inspection - -def dse_dimensions(expr): - """ - Collect all function dimensions used in a sympy expression. - """ - dimensions = [] - - for e in preorder_traversal(expr): - if isinstance(e, SymbolicData): - dimensions += [i for i in e.indices if i not in dimensions] - - return dimensions - - -def dse_symbols(expr): - """ - Collect defined and undefined symbols used in a sympy expression. - - Defined symbols are functions that have an associated :class - SymbolicData: object, or dimensions that are known to the devito - engine. Undefined symbols are generic `sympy.Function` or - `sympy.Symbol` objects that need to be substituted before generating - operator C code. - """ - defined = set() - undefined = set() - - for e in preorder_traversal(expr): - if isinstance(e, SymbolicData): - defined.add(e.func(*e.indices)) - elif isinstance(e, Function): - undefined.add(e) - elif isinstance(e, Symbol): - undefined.add(e) - - return list(defined), list(undefined) - - -def dse_dtype(expr): - """ - Try to infer the data type of an expression. - """ - dtypes = [e.dtype for e in preorder_traversal(expr) if hasattr(e, 'dtype')] - return np.find_common_type(dtypes, []) - - -# Manipulation - -def dse_indexify(expr): - """ - Convert functions into indexed matrix accesses in sympy expression. - - :param expr: sympy function expression to be converted. - """ - replacements = {} - - for e in preorder_traversal(expr): - if hasattr(e, 'indexed'): - replacements[e] = e.indexify() - - return expr.xreplace(replacements) - - -def dse_rewrite(expr, mode='advanced'): - """ - Transform expressions to reduce their operation count. - - :param expr: the target expression - :param mode: drive the expression transformation. Available modes are - ['basic', 'advanced' (default)]. Currently, with 'basic', only - common sub-expressions elimination is applied. With 'advanced', - all transformations applied in 'basic' are applied, plus - factorization of common terms and constants. - """ - - if mode is True: - return Rewriter(expr).run(mode='advanced') - elif mode in ['basic', 'advanced']: - return Rewriter(expr).run(mode) - elif not mode: - return expr - else: - warning("Illegal rewrite mode %s" % str(mode)) - return expr - - -class Rewriter(object): - - """ - Transform expressions to reduce their operation count. - """ - - # Do more factorization sweeps if the expression operation count is - # greater than this threshold - FACTORIZER_THS = 15 - - def __init__(self, expr): - self.expr = expr - - def run(self, mode): - processed = self.expr - - if mode in ['basic', 'advanced']: - processed = self._cse() - - if mode in ['advanced']: - processed = self._factorize(processed) - - processed = self._finalize(processed) - - return processed - - def _factorize(self, exprs): - """ - Collect terms in each expr in exprs based on the following heuristic: - - * Collect all literals; - * Collect all temporaries produced by CSE; - * If the expression has an operation count higher than - self.FACTORIZER_THS, then this is applied recursively until - no more factorization opportunities are available. - """ - if exprs is None: - exprs = self.expr - if not isinstance(exprs, list): - exprs = [exprs] - - processed = [] - cost_original, cost_processed = 1, 1 - for expr in exprs: - handle = collect_nested(expr) - - cost_expr = estimate_cost(expr) - cost_original += cost_expr - - cost_handle = estimate_cost(handle) - - if cost_handle < cost_expr and cost_handle >= Rewriter.FACTORIZER_THS: - handle_prev = handle - cost_prev = cost_expr - while cost_handle < cost_prev: - handle_prev, handle = handle, collect_nested(handle) - cost_prev, cost_handle = cost_handle, estimate_cost(handle) - cost_handle, handle = cost_prev, handle_prev - - processed.append(handle) - cost_processed += cost_handle - - out = perfok if cost_processed < cost_original else perfbad - out("Rewriter: %d --> %d flops (Gain: %.2f X)" % - (cost_original, cost_processed, float(cost_original)/cost_processed)) - - return processed - - def _cse(self, exprs=None): - """ - Perform common subexpression elimination. - """ - if exprs is None: - exprs = self.expr - if not isinstance(exprs, list): - exprs = [exprs] - - temps, stencils = cse(exprs, numbered_symbols(_temp_prefix)) - - # Restores the LHS - for i in range(len(exprs)): - stencils[i] = Eq(exprs[i].lhs, stencils[i].rhs) - - to_revert = {} - to_keep = [] - - # Restores IndexedBases if they are collected by CSE and - # reverts changes to simple index operations (eg: t - 1) - for temp, value in temps: - if isinstance(value, IndexedBase): - to_revert[temp] = value - elif isinstance(value, Indexed): - to_revert[temp] = value - elif isinstance(value, sympy.Add) and not \ - set([t, x, y, z]).isdisjoint(set(value.args)): - to_revert[temp] = value - else: - to_keep.append((temp, value)) - - # Restores the IndexedBases and the Indexes in the assignments to revert - for temp, value in to_revert.items(): - s_dict = {} - for arg in preorder_traversal(value): - if isinstance(arg, Indexed): - new_indices = [] - for index in arg.indices: - if index in to_revert: - new_indices.append(to_revert[index]) - else: - new_indices.append(index) - if arg.base.label in to_revert: - s_dict[arg] = Indexed(to_revert[value.base.label], *new_indices) - to_revert[temp] = value.xreplace(s_dict) - - subs_dict = {} - - # Builds a dictionary of the replacements - for expr in stencils + [assign for temp, assign in to_keep]: - for arg in preorder_traversal(expr): - if isinstance(arg, Indexed): - new_indices = [] - for index in arg.indices: - if index in to_revert: - new_indices.append(to_revert[index]) - else: - new_indices.append(index) - if arg.base.label in to_revert: - subs_dict[arg] = Indexed(to_revert[arg.base.label], *new_indices) - elif tuple(new_indices) != arg.indices: - subs_dict[arg] = Indexed(arg.base, *new_indices) - if arg in to_revert: - subs_dict[arg] = to_revert[arg] - - def recursive_replace(handle, subs_dict): - replaced = [] - for i in handle: - old, new = i, i.xreplace(subs_dict) - while new != old: - old, new = new, new.xreplace(subs_dict) - replaced.append(new) - return replaced - - stencils = recursive_replace(stencils, subs_dict) - to_keep = recursive_replace([Eq(temp, assign) for temp, assign in to_keep], - subs_dict) - - # If the RHS of a temporary variable is the LHS of a stencil, - # update the value of the temporary variable after the stencil - new_stencils = [] - for stencil in stencils: - new_stencils.append(stencil) - for temp in to_keep: - if stencil.lhs in preorder_traversal(temp.rhs): - new_stencils.append(temp) - break - - # Reshuffle to make sure temporaries come later than their read values - processed = OrderedDict([(i.lhs, i) for i in to_keep + new_stencils]) - temporaries = set(processed.keys()) - ordered = OrderedDict() - while processed: - k, v = processed.popitem(last=False) - temporary_reads = terminals(v.rhs) & temporaries - {v.lhs} - if all(i in ordered for i in temporary_reads): - ordered[k] = v - else: - # Must wait for some earlier temporaries, push back into queue - processed[k] = v - - return list(ordered.values()) - - def _finalize(self, exprs): - """ - Make sure that any subsequent sympy operation applied to the expressions - in ``exprs`` does not alter the structure of the transformed objects. - """ - return [unevaluate_arithmetic(e) for e in exprs] - - -# Creation - -def dse_tolambda(exprs): - """ - Tranform an expression into a lambda. - - :param exprs: an expression or a list of expressions. - """ - exprs = exprs if isinstance(exprs, list) else [exprs] - - lambdas = [] - - for expr in exprs: - terms = free_terms(expr.rhs) - term_symbols = [symbols("i%d" % i) for i in range(len(terms))] - - # Substitute IndexedBase references to simple variables - # lambdify doesn't support IndexedBase references in expressions - tolambdify = expr.rhs.subs(dict(zip(terms, term_symbols))) - lambdified = lambdify(term_symbols, tolambdify) - lambdas.append((lambdified, terms)) - - return lambdas - - -# Utilities - -def free_terms(expr): - """ - Find the free terms in an expression. - """ - found = [] - - for term in expr.args: - if isinstance(term, Indexed): - found.append(term) - else: - found += free_terms(term) - - return found - - -def terminals(expr, discard_indexed=False): - indexed = list(expr.find(Indexed)) - - # To be discarded - junk = flatten(i.atoms() for i in indexed) - - symbols = list(expr.find(Symbol)) - symbols = [i for i in symbols if i not in junk] - - if discard_indexed: - return set(symbols) - else: - return set(indexed + symbols) - - -def collect_nested(expr): - """ - Collect terms appearing in expr, checking all levels of the expression tree. - - :param expr: the expression to be factorized. - """ - - def run(expr): - # Return semantic (rebuilt expression, factorization candidates) - - if expr.is_Float: - return expr.func(*expr.atoms()), [expr] - elif isinstance(expr, Indexed): - return expr.func(*expr.args), [] - elif expr.is_Symbol: - return expr.func(expr.name), [expr] - elif expr in [S.Zero, S.One, S.NegativeOne, S.Half]: - return expr.func(), [expr] - elif expr.is_Atom: - return expr.func(*expr.atoms()), [] - elif expr.is_Add: - rebuilt, candidates = zip(*[run(arg) for arg in expr.args]) - - w_numbers = [i for i in rebuilt if any(j.is_Number for j in i.args)] - wo_numbers = [i for i in rebuilt if i not in w_numbers] - - w_numbers = collect_const(expr.func(*w_numbers)) - wo_numbers = expr.func(*wo_numbers) - - if wo_numbers: - for i in flatten(candidates): - wo_numbers = collect(wo_numbers, i) - - rebuilt = expr.func(w_numbers, wo_numbers) - return rebuilt, [] - elif expr.is_Mul: - rebuilt, candidates = zip(*[run(arg) for arg in expr.args]) - rebuilt = collect_const(expr.func(*rebuilt)) - return rebuilt, flatten(candidates) - else: - rebuilt, candidates = zip(*[run(arg) for arg in expr.args]) - return expr.func(*rebuilt), flatten(candidates) - - return run(expr)[0] - - -def estimate_cost(handle): - try: - # Is it a plain SymPy object ? - iter(handle) - except TypeError: - handle = [handle] - try: - # Is it a dict ? - handle = handle.values() - except AttributeError: - try: - # Must be a list of dicts then - handle = flatten(i.values() for i in handle) - except AttributeError: - pass - try: - # At this point it must be a list of SymPy objects - # We don't count non floating point operations - handle = [i.rhs if i.is_Equality else i for i in handle] - total_ops = sum(count_ops(i.args) for i in handle) - non_flops = sum(count_ops(i.find(Indexed)) for i in handle) - return total_ops - non_flops - except: - warning("Cannot estimate cost of %s" % str(handle)) - - -def unevaluate_arithmetic(expr): - """ - Reconstruct ``expr`` turning all :class:`sympy.Mul` and :class:`sympy.Add` - into, respectively, :class:`devito.Mul` and :class:`devito.Add`. - """ - if expr.is_Float: - return expr.func(*expr.atoms()) - elif isinstance(expr, Indexed): - return expr.func(*expr.args) - elif expr.is_Symbol: - return expr.func(expr.name) - elif expr in [S.Zero, S.One, S.NegativeOne, S.Half]: - return expr.func() - elif expr.is_Atom: - return expr.func(*expr.atoms()) - if expr.is_Add: - rebuilt_args = [unevaluate_arithmetic(e) for e in expr.args] - return Add(*rebuilt_args, evaluate=False) - elif expr.is_Mul: - rebuilt_args = [unevaluate_arithmetic(e) for e in expr.args] - return Mul(*rebuilt_args, evaluate=False) - else: - return expr.func(*[unevaluate_arithmetic(e) for e in expr.args]) - - -# Extended Sympy hierarchy - -class UnevaluatedExpr(Expr): - - """ - Use :class:`UnevaluatedExpr` in place of :class:`sympy.Expr` to prevent - xreplace from unpicking factorizations. - """ - - def xreplace(self, rule): - if self in rule: - return rule[self] - elif rule: - args = [] - for a in self.args: - try: - args.append(a.xreplace(rule)) - except AttributeError: - args.append(a) - args = tuple(args) - if not _aresame(args, self.args): - return self.func(*args, evaluate=False) - return self - - -class Mul(sympy.Mul, UnevaluatedExpr): - pass - - -class Add(sympy.Add, UnevaluatedExpr): - pass diff --git a/examples/benchmark.py b/examples/benchmark.py index 161099364b..3553baa368 100644 --- a/examples/benchmark.py +++ b/examples/benchmark.py @@ -34,6 +34,9 @@ parser.add_argument(dest="execmode", nargs="?", default="run", choices=["run", "test", "bench", "plot"], help="Exec modes") + parser.add_argument("--bench-mode", "-bm", dest="benchmode", default="all", + choices=["all", "blocking", "dse"], + help="Choose what to benchmark (all, blocking, dse).") parser.add_argument(dest="compiler", nargs="?", default=environ.get("DEVITO_ARCH", "gnu"), choices=compiler_registry.keys(), @@ -62,7 +65,9 @@ type=int, help="End time of the simulation in ms") devito = parser.add_argument_group("Devito") - devito.add_argument("-dse", default="advanced", choices=["basic", "advanced"], + devito.add_argument("-dse", default="advanced", nargs="*", + choices=["basic", "factorize", "approx-trigonometry", + "glicm", "advanced"], help="Devito symbolic engine (DSE) mode") devito.add_argument("-a", "--auto_tuning", action="store_true", help=("Benchmark with auto tuning on and off. " + @@ -92,6 +97,7 @@ parameters = vars(args).copy() del parameters["execmode"] + del parameters["benchmode"] del parameters["problem"] del parameters["resultsdir"] del parameters["plotdir"] @@ -118,11 +124,18 @@ raise ImportError("Could not find opescibench utility package.\n" "Please install from https://github.com/opesci/opescibench") - if parameters["auto_tuning"]: + if args.benchmode == 'all': parameters["auto_tuning"] = [True, False] - - if parameters["dse"]: parameters["dse"] = ["basic", "advanced"] + elif args.benchmode == 'blocking': + parameters["auto_tuning"] = [True, False] + parameters["dse"] = ["basic"] + elif args.benchmode == 'dse': + parameters["auto_tuning"] = [False] + parameters["dse"] = ["basic", + ('basic', 'factorize'), + ('basic', 'glicm'), + "advanced"] if args.execmode == "test": values_sweep = [v if isinstance(v, list) else [v] for v in parameters.values()] @@ -184,10 +197,16 @@ def run(self, *args, **kwargs): parameters["time_order"]) at_runs = [True, False] - dse_runs = ["basic", "advanced"] + dse_runs = ["basic", ("basic", "factorize"), ("basic", "glicm"), "advanced"] runs = list(product(at_runs, dse_runs)) - style = {(True, 'advanced'): 'ob', (True, 'basic'): 'or', - (False, 'advanced'): 'Db', (False, 'basic'): 'Dr'} + styles = { + # AT true + (True, 'advanced'): 'ob', (True, 'basic'): 'or', + (True, ('basic', 'factorize')): 'og', (True, ('basic', 'glicm')): 'oy', + # AT false + (False, 'advanced'): 'Db', (False, 'basic'): 'Dr', + (False, ('basic', 'factorize')): 'Dg', (False, ('basic', 'glicm')): 'Dy' + } with RooflinePlotter(figname=name, plotdir=args.plotdir, max_bw=args.max_bw, max_flops=args.max_flops, @@ -204,6 +223,6 @@ def run(self, *args, **kwargs): point_annotate = {'s': "%.1f s" % time_value, 'xytext': (-22, 18), 'size': 6, 'weight': 'bold'} if args.point_runtime else None - plot.add_point(gflops=gflopss, oi=oi_value, style=style[run], + plot.add_point(gflops=gflopss, oi=oi_value, style=styles[run], oi_line=run[0], label=label, oi_annotate=oi_annotate, annotate=point_annotate) diff --git a/examples/tti/tti_operators.py b/examples/tti/tti_operators.py index 345e12278f..b5ad7a23ac 100644 --- a/examples/tti/tti_operators.py +++ b/examples/tti/tti_operators.py @@ -17,12 +17,10 @@ class ForwardOperator(Operator): :param data: IShot() object containing the acquisition geometry and field data :param: time_order: Time discretization order :param: spc_order: Space discretization order - :param: trigonometry : COS/SIN functions choice. The default is to use C functions :param: u_ini : wavefield at the three first time step for non-zero initial condition - `Bhaskara` uses a rational approximation. """ - def __init__(self, model, src, damp, data, time_order=2, spc_order=4, save=False, - trigonometry='normal', u_ini=None, **kwargs): + def __init__(self, model, src, damp, data, time_order=2, spc_order=4, + save=False, u_ini=None, **kwargs): nt, nrec = data.shape nt, nsrc = src.shape dt = model.get_critical_dt() @@ -95,40 +93,18 @@ def __init__(self, model, src, damp, data, time_order=2, spc_order=4, save=False ndim=len(damp.shape), dtype=damp.dtype, nbpml=model.nbpml) source.data[:] = .5*src.traces[:] + s, h = symbols('s h') - def ssin(angle, approx): - if angle == 0: - return 0.0 - else: - if approx == 'Bhaskara': - return (16.0 * angle * (3.1416 - abs(angle)) / - (49.3483 - 4.0 * abs(angle) * (3.1416 - abs(angle)))) - elif approx == 'Taylor': - return angle - (angle * angle * angle / 6.0 * - (1.0 - angle * angle / 20.0)) - else: - return sin(angle) - - def ccos(angle, approx): - if angle == 0: - return 1.0 - else: - if approx == 'Bhaskara': - return ssin(angle, 'Bhaskara') - elif approx == 'Taylor': - return 1 - .5 * angle * angle * (1 - angle * angle / 12.0) - else: - return cos(angle) - - ang0 = ccos(theta, trigonometry) - ang1 = ssin(theta, trigonometry) spc_brd = spc_order/2 - # Derive stencil from symbolic equation + ang0 = cos(theta) + ang1 = sin(theta) if len(m.shape) == 3: - ang2 = ccos(phi, trigonometry) - ang3 = ssin(phi, trigonometry) + ang2 = cos(phi) + ang3 = sin(phi) + + # Derive stencil from symbolic equation Gyp = (ang3 * u.dx - ang2 * u.dyr) Gyy = (-first_derivative(Gyp * ang3, dim=x, side=centered, order=spc_brd) - diff --git a/tests/test_gradient.py b/tests/test_gradient.py index 9b52c85395..770c746d89 100644 --- a/tests/test_gradient.py +++ b/tests/test_gradient.py @@ -7,7 +7,6 @@ from examples.containers import IGrid, IShot -@pytest.mark.xfail(reason='Numerical accuracy with np.float32') class TestGradient(object): @pytest.fixture(params=[(70, 80)]) def acoustic(self, request, time_order, space_order): @@ -47,7 +46,7 @@ def smooth10(vel): else: dt = model.get_critical_dt() t0 = 0.0 - tn = 500.0 + tn = 750.0 nt = int(1+(tn-t0)/dt) # Set up the source as Ricker wavelet for f0 @@ -89,7 +88,7 @@ def source(t, f0): def time_order(self, request): return request.param - @pytest.fixture(params=[2]) + @pytest.fixture(params=[4]) def space_order(self, request): return request.param @@ -102,7 +101,7 @@ def test_grad(self, acoustic): # Actual Gradient test G = np.dot(gradient.reshape(-1), acoustic[1].model.pad(acoustic[2]).reshape(-1)) # FWI Gradient test - H = [1, 0.1, 0.01, .001, 0.0001, 0.00001, 0.000001] + H = [0.5, 0.25, .125, 0.0625, 0.0312, 0.015625, 0.0078125] error1 = np.zeros(7) error2 = np.zeros(7) for i in range(0, 7): @@ -124,8 +123,8 @@ def test_grad(self, acoustic): p2 = np.polyfit(np.log10(H), np.log10(error2), 1) print(p1) print(p2) - assert np.isclose(p1[0], 1.0, rtol=0.05) - assert np.isclose(p2[0], 2.0, rtol=0.05) + assert np.isclose(p1[0], 1.0, rtol=0.1) + assert np.isclose(p2[0], 2.0, rtol=0.1) if __name__ == "__main__": diff --git a/tests/test_rewriter.py b/tests/test_rewriter.py index 7ef648f11d..b291bbbe70 100644 --- a/tests/test_rewriter.py +++ b/tests/test_rewriter.py @@ -1,5 +1,10 @@ import numpy as np +import pytest + +from devito.dse.graph import temporaries_graph +from devito.dse.symbolics import rewrite + from examples.acoustic.Acoustic_codegen import Acoustic_cg from examples.containers import IGrid, IShot from examples.tti.tti_example import setup @@ -77,17 +82,47 @@ def tti_operator(dse=False): return handle -def test_tti_rewrite_basic(): - output1 = tti_operator(dse=None).apply() - output2 = tti_operator(dse='basic').apply() +@pytest.fixture(scope="session") +def tti_nodse(): + # FIXME: note that np.copy is necessary because of the broken caching system + output = tti_operator(dse=None).apply() + return (np.copy(output[0].data), np.copy(output[1].data)) + + +def test_tti_rewrite_temporaries_graph(): + operator = tti_operator() + handle = rewrite(operator.stencils, mode='basic') + + graph = temporaries_graph(handle.exprs) + + assert len([v for v in graph.values() if v.is_terminal]) == 2 # u and v + assert len(graph) == len(handle.exprs) + assert all(v.reads or v.readby for v in graph.values()) + + +def test_tti_rewrite_basic(tti_nodse): + output = tti_operator(dse='basic').apply() + + assert np.allclose(tti_nodse[0], output[0].data, atol=10e-3) + assert np.allclose(tti_nodse[1], output[1].data, atol=10e-3) + + +def test_tti_rewrite_factorizer(tti_nodse): + output = tti_operator(dse=('basic', 'factorize')).apply() + + assert np.allclose(tti_nodse[0], output[0].data, atol=10e-3) + assert np.allclose(tti_nodse[1], output[1].data, atol=10e-3) + + +def test_tti_rewrite_trigonometry(tti_nodse): + output = tti_operator(dse=('basic', 'approx-trigonometry')).apply() - assert np.allclose(output1[0].data, output2[0].data, atol=10e-3) - assert np.allclose(output1[1].data, output2[1].data, atol=10e-3) + assert np.allclose(tti_nodse[0], output[0].data, atol=10e-1) + assert np.allclose(tti_nodse[1], output[1].data, atol=10e-1) -def test_tti_rewrite_advanced(): - output1 = tti_operator(dse=None).apply() - output2 = tti_operator(dse='advanced').apply() +def test_tti_rewrite_advanced(tti_nodse): + output = tti_operator(dse='advanced').apply() - assert np.allclose(output1[0].data, output2[0].data, atol=10e-3) - assert np.allclose(output1[1].data, output2[1].data, atol=10e-3) + assert np.allclose(tti_nodse[0], output[0].data, atol=10e-1) + assert np.allclose(tti_nodse[1], output[1].data, atol=10e-1)