Skip to content

Commit

Permalink
sympy: Update printer to better handle precision for Pow and Abs
Browse files Browse the repository at this point in the history
  • Loading branch information
JDBetteridge committed Dec 20, 2024
1 parent 5a15896 commit 5e98bec
Showing 1 changed file with 29 additions and 10 deletions.
39 changes: 29 additions & 10 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

from mpmath.libmp import prec_to_dps, to_str
from packaging.version import Version
from numbers import Real

from sympy.core import S
from sympy.core.numbers import equal_valued, Float
from sympy.logic.boolalg import BooleanFunction
from sympy.printing.precedence import PRECEDENCE_VALUES, precedence
from sympy.printing.c import C99CodePrinter
Expand Down Expand Up @@ -122,15 +125,22 @@ def _print_math_func(self, expr, nest=False, known=None):
return f'{cname}({args})'

def _print_Pow(self, expr):
# Need to override because of issue #1627
# E.g., (Pow(h_x, -1) AND h_x.dtype == np.float32) => 1.0F/h_x
try:
if expr.exp == -1 and self.single_prec():
PREC = precedence(expr)
return '1.0F/%s' % self.parenthesize(expr.base, PREC)
except AttributeError:
pass
return super()._print_Pow(expr)
# Completely reimplement `_print_Pow` from sympy, since it doesn't
# correctly handle precision
if "Pow" in self.known_functions:
return self._print_Function(expr)
PREC = precedence(expr)
suffix = 'f' if self.single_prec(expr) else ''
if equal_valued(expr.exp, -1):
return f'{self._print_Float(Float(1.0))}/' + \
f'{self.parenthesize(expr.base, PREC)}'
elif equal_valued(expr.exp, 0.5):
return f'{self._ns}sqrt{suffix}({self._print(expr.base)}'
elif expr.exp == S.One/3 and self.standard != 'C89':
return f'{self._ns}cbrt{suffix}({self._print(expr.base)})'
else:
return f'{self._ns}pow{suffix}({self._print(expr.base)}, ' + \
f'{self._print(expr.exp)})'

def _print_Mod(self, expr):
"""Print a Mod as a C-like %-based operation."""
Expand Down Expand Up @@ -159,7 +169,16 @@ def _print_Abs(self, expr):
if isinstance(self.compiler, AOMPCompiler):
return "fabs(%s)" % self._print(expr.args[0])
# Check if argument is an integer
func = "abs" if has_integer_args(*expr.args[0].args) else "fabs"
if has_integer_args(*expr.args[0].args):
func = "abs"
elif self.single_prec(expr):
func = "fabsf"
elif any([isinstance(a, Real) for a in expr.args[0].args]):
# The previous condition isn't sufficient to detect case with
# Python `float`s in that case, fall back to the "default"
func = "fabsf" if self.single_prec() else "fabs"
else:
func = "fabs"
return "%s(%s)" % (func, self._print(expr.args[0]))

def _print_Add(self, expr, order=None):
Expand Down

0 comments on commit 5e98bec

Please sign in to comment.