Skip to content

Commit

Permalink
api: enforce sympy shifts
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Oct 21, 2024
1 parent 7795225 commit f940580
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest-core-nompi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ jobs:
run : |
if [ "${{ runner.os }}" == 'macOS' ]; then
brew install llvm libomp
echo "/opt/homebrew/bin:/opt/homebrew/opt/llvm/bin" >> $GITHUB_PATH
echo "/opt/homebrew/opt/llvm/bin" >> $GITHUB_PATH
fi
id: set-tests

Expand Down
6 changes: 3 additions & 3 deletions devito/finite_differences/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from itertools import product

import numpy as np
from sympy import S, finite_diff_weights, cacheit, sympify, Function
from sympy import S, finite_diff_weights, cacheit, sympify, Function, Rational

from devito.tools import Tag, as_tuple
from devito.types.dimension import StencilDimension
Expand Down Expand Up @@ -308,8 +308,8 @@ def make_shift_x0(shift, ndim):
"""
if shift is None:
return lambda s, d, i, j: None
elif isinstance(shift, float):
return lambda s, d, i, j: d + s * d.spacing
elif sympify(shift).is_Number:
return lambda s, d, i, j: d + Rational(s) * d.spacing
elif type(shift) is tuple and np.shape(shift) == ndim:
if len(ndim) == 1:
return lambda s, d, i, j: d + s[j] * d.spacing
Expand Down
13 changes: 13 additions & 0 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from mpmath.libmp import prec_to_dps, to_str
from packaging.version import Version

from sympy.codegen.ast import float32, float64
from sympy.logic.boolalg import BooleanFunction
from sympy.printing.precedence import PRECEDENCE_VALUES, precedence
from sympy.printing.c import C99CodePrinter
Expand All @@ -18,6 +20,9 @@
__all__ = ['ccode']


_type_mapper = {np.float32: float32, np.float64: float64}


class CodePrinter(C99CodePrinter):

"""
Expand Down Expand Up @@ -179,12 +184,20 @@ def _print_Add(self, expr, order=None):

def _print_Float(self, expr):
"""Print a Float in C-like scientific notation."""
try:
# Make sure the float is in the correct format
expr = _type_mapper[self.dtype].cast_nocheck(expr)
rv = str(expr)
except KeyError:
pass

prec = expr._prec

if prec < 5:
dps = 0
else:
dps = prec_to_dps(expr._prec)

if self._settings["full_prec"] is True:
strip = False
elif self._settings["full_prec"] is False:
Expand Down
4 changes: 3 additions & 1 deletion tests/test_tensors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import sympy
from sympy import Rational

import pytest

Expand Down Expand Up @@ -372,7 +373,8 @@ def test_shifted_lap_of_vector(shift, ndim):
assert dfvi == ref


@pytest.mark.parametrize('shift, ndim', [(None, 2), (.5, 2), (.5, 3),
@pytest.mark.parametrize('shift, ndim', [(None, 2), (Rational(1/2), 2),
(Rational(1/2), 3),
(tuple([tuple([.5]*3)]*3), 3)])
def test_shifted_lap_of_tensor(shift, ndim):
grid = Grid(tuple([11]*ndim))
Expand Down

0 comments on commit f940580

Please sign in to comment.