From 5e98bec62f615436f7ff5befeeb87e3754b7909a Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Fri, 20 Dec 2024 13:29:25 +0000 Subject: [PATCH] sympy: Update printer to better handle precision for Pow and Abs --- devito/symbolics/printer.py | 39 +++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 12eb2baca4..e1a3280d31 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -7,7 +7,10 @@ from mpmath.libmp import prec_to_dps, to_str from packaging.version import Version +from numbers import Real +from sympy.core import S +from sympy.core.numbers import equal_valued, Float from sympy.logic.boolalg import BooleanFunction from sympy.printing.precedence import PRECEDENCE_VALUES, precedence from sympy.printing.c import C99CodePrinter @@ -122,15 +125,22 @@ def _print_math_func(self, expr, nest=False, known=None): return f'{cname}({args})' def _print_Pow(self, expr): - # Need to override because of issue #1627 - # E.g., (Pow(h_x, -1) AND h_x.dtype == np.float32) => 1.0F/h_x - try: - if expr.exp == -1 and self.single_prec(): - PREC = precedence(expr) - return '1.0F/%s' % self.parenthesize(expr.base, PREC) - except AttributeError: - pass - return super()._print_Pow(expr) + # Completely reimplement `_print_Pow` from sympy, since it doesn't + # correctly handle precision + if "Pow" in self.known_functions: + return self._print_Function(expr) + PREC = precedence(expr) + suffix = 'f' if self.single_prec(expr) else '' + if equal_valued(expr.exp, -1): + return f'{self._print_Float(Float(1.0))}/' + \ + f'{self.parenthesize(expr.base, PREC)}' + elif equal_valued(expr.exp, 0.5): + return f'{self._ns}sqrt{suffix}({self._print(expr.base)}' + elif expr.exp == S.One/3 and self.standard != 'C89': + return f'{self._ns}cbrt{suffix}({self._print(expr.base)})' + else: + return f'{self._ns}pow{suffix}({self._print(expr.base)}, ' + \ + f'{self._print(expr.exp)})' def _print_Mod(self, expr): """Print a Mod as a C-like %-based operation.""" @@ -159,7 +169,16 @@ def _print_Abs(self, expr): if isinstance(self.compiler, AOMPCompiler): return "fabs(%s)" % self._print(expr.args[0]) # Check if argument is an integer - func = "abs" if has_integer_args(*expr.args[0].args) else "fabs" + if has_integer_args(*expr.args[0].args): + func = "abs" + elif self.single_prec(expr): + func = "fabsf" + elif any([isinstance(a, Real) for a in expr.args[0].args]): + # The previous condition isn't sufficient to detect case with + # Python `float`s in that case, fall back to the "default" + func = "fabsf" if self.single_prec() else "fabs" + else: + func = "fabs" return "%s(%s)" % (func, self._print(expr.args[0])) def _print_Add(self, expr, order=None):