Skip to content

Commit

Permalink
api: enforce sympy shifts
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Oct 20, 2024
1 parent 7795225 commit 556f4aa
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
6 changes: 3 additions & 3 deletions devito/finite_differences/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from itertools import product

import numpy as np
from sympy import S, finite_diff_weights, cacheit, sympify, Function
from sympy import S, finite_diff_weights, cacheit, sympify, Function, Rational

from devito.tools import Tag, as_tuple
from devito.types.dimension import StencilDimension
Expand Down Expand Up @@ -308,8 +308,8 @@ def make_shift_x0(shift, ndim):
"""
if shift is None:
return lambda s, d, i, j: None
elif isinstance(shift, float):
return lambda s, d, i, j: d + s * d.spacing
elif sympify(shift).is_Number:
return lambda s, d, i, j: d + Rational(s) * d.spacing
elif type(shift) is tuple and np.shape(shift) == ndim:
if len(ndim) == 1:
return lambda s, d, i, j: d + s[j] * d.spacing
Expand Down
13 changes: 13 additions & 0 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from mpmath.libmp import prec_to_dps, to_str
from packaging.version import Version

from sympy.codegen.ast import float32, float64
from sympy.logic.boolalg import BooleanFunction
from sympy.printing.precedence import PRECEDENCE_VALUES, precedence
from sympy.printing.c import C99CodePrinter
Expand All @@ -18,6 +20,9 @@
__all__ = ['ccode']


_type_mapper = {np.float32: float32, np.float64: float64}


class CodePrinter(C99CodePrinter):

"""
Expand Down Expand Up @@ -179,12 +184,20 @@ def _print_Add(self, expr, order=None):

def _print_Float(self, expr):
"""Print a Float in C-like scientific notation."""
try:
# Make sure the float is in the correct format
expr = _type_mapper[self.dtype].cast_nocheck(expr)
rv = str(expr)
except KeyError:
pass

prec = expr._prec

if prec < 5:
dps = 0
else:
dps = prec_to_dps(expr._prec)

if self._settings["full_prec"] is True:
strip = False
elif self._settings["full_prec"] is False:
Expand Down

0 comments on commit 556f4aa

Please sign in to comment.