Skip to content

Commit

Permalink
Merge pull request #2190 from devitocodes/no-more-dyn-classes-final
Browse files Browse the repository at this point in the history
dsl: No more dynamic classes for AbstractFunctions
  • Loading branch information
mloubout authored Aug 23, 2023
2 parents 67e5779 + 25c856a commit 7763d03
Show file tree
Hide file tree
Showing 22 changed files with 446 additions and 169 deletions.
3 changes: 3 additions & 0 deletions devito/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
from devito.mpi.routines import mpi_registry
from devito.operator import profiler_registry, operator_registry

# Apply monkey-patching while we wait for our patches to be upstreamed and released
from devito.mpatches import * # noqa


from ._version import get_versions # noqa
__version__ = get_versions()['version']
Expand Down
25 changes: 19 additions & 6 deletions devito/arch/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy.ctypeslib as npct
from codepy.jit import compile_from_string
from codepy.toolchain import GCCToolchain, call_capture_output
from codepy.toolchain import GCCToolchain, call_capture_output as _call_capture_output

from devito.arch import (AMDGPUX, Cpu64, M1, NVIDIAX, POWER8, POWER9, GRAVITON,
INTELGPUX, get_nvidia_cc, check_cuda_runtime,
Expand All @@ -19,7 +19,7 @@
from devito.logger import debug, warning, error
from devito.parameters import configuration
from devito.tools import (as_list, change_directory, filter_ordered,
memoized_func, memoized_meth, make_tempdir)
memoized_func, make_tempdir)

__all__ = ['sniff_mpi_distro', 'compiler_registry']

Expand Down Expand Up @@ -123,6 +123,15 @@ def sniff_mpi_flags(mpicc='mpicc'):
return compile_flags.split(), link_flags.split()


@memoized_func
def call_capture_output(cmd):
"""
Memoize calls to codepy's `call_capture_output` to avoid leaking memory due
to some prefork/subprocess voodoo.
"""
return _call_capture_output(cmd)


class Compiler(GCCToolchain):
"""
Base class for all compiler classes.
Expand Down Expand Up @@ -220,12 +229,16 @@ def __new_with__(self, **kwargs):
def name(self):
return self.__class__.__name__

@memoized_meth
def get_version(self):
result, stdout, stderr = call_capture_output((self.cc, "--version"))
if result != 0:
raise RuntimeError(f"version query failed: {stderr}")
return stdout

def get_jit_dir(self):
"""A deterministic temporary directory for jit-compiled objects."""
return make_tempdir('jitcache')

@memoized_meth
def get_codepy_dir(self):
"""A deterministic temporary directory for the codepy cache."""
return make_tempdir('codepy')
Expand Down Expand Up @@ -729,9 +742,9 @@ def __init__(self, *args, **kwargs):

def get_version(self):
if configuration['mpi']:
cmd = [self.cc, "-cc=%s" % self.CC, "--version"]
cmd = (self.cc, "-cc=%s" % self.CC, "--version")
else:
cmd = [self.cc, "--version"]
cmd = (self.cc, "--version")
result, stdout, stderr = call_capture_output(cmd)
if result != 0:
raise RuntimeError(f"version query failed: {stderr}")
Expand Down
4 changes: 3 additions & 1 deletion devito/builtins/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def sum(f, dims=None):
elif f.is_SparseTimeFunction:
if f.time_dim in dims:
# Sum over time -> SparseFunction
new_coords = f.coordinates._rebuild(name="%ssum_coords" % f.name)
new_coords = f.coordinates._rebuild(
name="%ssum_coords" % f.name, initializer=f.coordinates.initializer
)
out = dv.SparseFunction(name="%ssum" % f.name, grid=f.grid,
dimensions=new_dims, npoint=f.shape[1],
coordinates=new_coords)
Expand Down
61 changes: 44 additions & 17 deletions devito/finite_differences/derivative.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,18 @@ def __new__(cls, expr, *dims, **kwargs):
obj._deriv_order = orders if skip else DimensionTuple(*orders, getters=obj._dims)
obj._side = kwargs.get("side")
obj._transpose = kwargs.get("transpose", direct)
obj._ppsubs = as_tuple(frozendict(i) for i in
kwargs.get("subs", kwargs.get("_ppsubs", [])))

ppsubs = kwargs.get("subs", kwargs.get("_ppsubs", []))
processed = []
if ppsubs:
for i in ppsubs:
try:
processed.append(frozendict(i))
except AttributeError:
# E.g. `i` is a Transform object
processed.append(i)
obj._ppsubs = tuple(processed)

obj._x0 = frozendict(kwargs.get('x0', {}))
return obj

Expand Down Expand Up @@ -207,34 +217,51 @@ def _new_from_self(self, **kwargs):
def func(self, expr, *args, **kwargs):
return self._new_from_self(expr=expr, **kwargs)

def subs(self, *args, **kwargs):
"""
Bypass sympy.Subs as Devito has its own lazy evaluation mechanism.
"""
# Check if we are calling subs(self, old, new, **hint) in which case
# return the standard substitution. Need to check `==` rather than `is`
# because a new derivative could be created i.e `f.dx.subs(f.dx, y)`
if len(args) == 2 and args[0] == self:
return args[1]
try:
rules = dict(*args)
except TypeError:
rules = dict((args,))
kwargs.pop('simultaneous', None)
return self.xreplace(rules, **kwargs)
def _subs(self, old, new, **hints):
# Basic case
if old == self:
return new
# Is it in expr?
if self.expr.has(old):
newexpr = self.expr._subs(old, new, **hints)
try:
return self._new_from_self(expr=newexpr)
except ValueError:
# Expr replacement leads to non-differentiable expression
# e.g `f.dx.subs(f: 1) = 1.dx = 0`
# returning zero
return sympy.S.Zero

# In case `x0` was passed as a substitution instead of `(x0=`
if str(old) == 'x0':
return self._new_from_self(x0={self.dims[0]: new})

# Trying to substitute by another derivative with different metadata
# Only need to check if is a Derivative since one for the cases above would
# have found it
if isinstance(old, Derivative):
return self

# Fall back if we didn't catch any special case
return self.xreplace({old: new}, **hints)

def _xreplace(self, subs):
"""
This is a helper method used internally by SymPy. We exploit it to postpone
substitutions until evaluation.
"""
# Return if no subs
if not subs:
return self, False

# Check if trying to replace the whole expression
if self in subs:
new = subs.pop(self)
try:
return new._xreplace(subs)
except AttributeError:
return new, True

subs = self._ppsubs + (subs,) # Postponed substitutions
return self._new_from_self(subs=subs), True

Expand Down
8 changes: 4 additions & 4 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ def _eval_at(self, func):
return self.func(*[getattr(a, '_eval_at', lambda x: a)(func) for a in self.args])

def _subs(self, old, new, **hints):
if old is self:
if old == self:
return new
if old is new:
if old == new:
return self
args = list(self.args)
for i, arg in enumerate(args):
Expand Down Expand Up @@ -613,15 +613,15 @@ def __init_finalize__(self, *args, **kwargs):

def __eq__(self, other):
return (isinstance(other, Weights) and
self.dimension is other.dimension and
self.name == other.name and
self.dimension == other.dimension and
self.indices == other.indices and
self.weights == other.weights)

__hash__ = sympy.Basic.__hash__

def _hashable_content(self):
return super()._hashable_content() + (self.name,) + tuple(self.weights)
return (self.name, self.dimension, hash(tuple(self.weights)))

@property
def dimension(self):
Expand Down
1 change: 1 addition & 0 deletions devito/mpatches/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .rationaltools import * # noqa
95 changes: 95 additions & 0 deletions devito/mpatches/rationaltools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Tools for manipulation of rational expressions. """

import importlib

import sympy
from sympy.core import Basic, Add, sympify
from sympy.core.exprtools import gcd_terms
from sympy.utilities import public
from sympy.utilities.iterables import iterable

__all__ = []


@public
def together(expr, deep=False, fraction=True):
"""
Denest and combine rational expressions using symbolic methods.
This function takes an expression or a container of expressions
and puts it (them) together by denesting and combining rational
subexpressions. No heroic measures are taken to minimize degree
of the resulting numerator and denominator. To obtain completely
reduced expression use :func:`~.cancel`. However, :func:`~.together`
can preserve as much as possible of the structure of the input
expression in the output (no expansion is performed).
A wide variety of objects can be put together including lists,
tuples, sets, relational objects, integrals and others. It is
also possible to transform interior of function applications,
by setting ``deep`` flag to ``True``.
By definition, :func:`~.together` is a complement to :func:`~.apart`,
so ``apart(together(expr))`` should return expr unchanged. Note
however, that :func:`~.together` uses only symbolic methods, so
it might be necessary to use :func:`~.cancel` to perform algebraic
simplification and minimize degree of the numerator and denominator.
Examples
========
>>> from sympy import together, exp
>>> from sympy.abc import x, y, z
>>> together(1/x + 1/y)
(x + y)/(x*y)
>>> together(1/x + 1/y + 1/z)
(x*y + x*z + y*z)/(x*y*z)
>>> together(1/(x*y) + 1/y**2)
(x + y)/(x*y**2)
>>> together(1/(1 + 1/x) + 1/(1 + 1/y))
(x*(y + 1) + y*(x + 1))/((x + 1)*(y + 1))
>>> together(exp(1/x + 1/y))
exp(1/y + 1/x)
>>> together(exp(1/x + 1/y), deep=True)
exp((x + y)/(x*y))
>>> together(1/exp(x) + 1/(x*exp(x)))
(x + 1)*exp(-x)/x
>>> together(1/exp(2*x) + 1/(x*exp(3*x)))
(x*exp(x) + 1)*exp(-3*x)/x
"""
def _together(expr):
if isinstance(expr, Basic):
if expr.is_Atom or (expr.is_Function and not deep):
return expr
elif expr.is_Add:
return gcd_terms(list(map(_together, Add.make_args(expr))),
fraction=fraction)
elif expr.is_Pow:
base = _together(expr.base)

if deep:
exp = _together(expr.exp)
else:
exp = expr.exp

return expr.func(base, exp)
else:
return expr.func(*[_together(arg) for arg in expr.args])
elif iterable(expr):
return expr.__class__([_together(ex) for ex in expr])

return expr

return _together(sympify(expr))


# Apply the monkey patch
simplify = importlib.import_module(sympy.simplify.__module__)
simplify.together = together
30 changes: 21 additions & 9 deletions devito/operations/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ def linsolve(expr, target, **kwargs):
The symbol w.r.t. which the equation is rearranged. May be a `Function`
or any other symbolic object.
"""
c = factorize_target(expr, target)
c, expr = factorize_target(expr, target)
if c != 0:
return -expr.xreplace({target: 0})/c
return -expr/c
raise SolveError("No linear solution found")


Expand All @@ -102,27 +102,39 @@ def _(expr):

@singledispatch
def factorize_target(expr, target):
return 1 if expr is target else 0
return (1, 0) if expr == target else (0, expr)


@factorize_target.register(Add)
@factorize_target.register(EvalDerivative)
def _(expr, target):
c = 0
if not expr.has(target):
return c
return c, expr

args = []
for a in expr.args:
c += factorize_target(a, target)
return c
c1, a1 = factorize_target(a, target)
c += c1
args.append(a1)

return c, expr.func(*args, evaluate=False)


@factorize_target.register(Mul)
def _(expr, target):
if not expr.has(target):
return 0
return 0, expr

c = 1
args = []
for a in expr.args:
c *= a if not a.has(target) else factorize_target(a, target)
return c
if not a.has(target):
c *= a
args.append(a)
else:
c1, a1 = factorize_target(a, target)
c *= c1
args.append(a1)

return c, expr.func(*args, evaluate=False)
11 changes: 7 additions & 4 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from devito.arch.compiler import AOMPCompiler
from devito.symbolics.inspection import has_integer_args
from devito.types.basic import AbstractFunction

__all__ = ['ccode']

Expand Down Expand Up @@ -44,10 +45,12 @@ def parenthesize(self, item, level, strict=False):
return super().parenthesize(item, level, strict=strict)

def _print_Function(self, expr):
# There exist no unknown Functions
if expr.func.__name__ not in self.known_functions:
self.known_functions[expr.func.__name__] = expr.func.__name__
return super()._print_Function(expr)
if isinstance(expr, AbstractFunction):
return str(expr)
else:
if expr.func.__name__ not in self.known_functions:
self.known_functions[expr.func.__name__] = expr.func.__name__
return super()._print_Function(expr)

def _print_CondEq(self, expr):
return "%s == %s" % (self._print(expr.lhs), self._print(expr.rhs))
Expand Down
5 changes: 5 additions & 0 deletions devito/symbolics/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ def q_function(expr):
return isinstance(expr, DiscreteFunction)


def q_derivative(expr):
from devito.finite_differences.derivative import Derivative
return isinstance(expr, Derivative)


def q_terminal(expr):
return (expr.is_Symbol or
expr.is_Indexed or
Expand Down
Loading

0 comments on commit 7763d03

Please sign in to comment.