From 3da00806215c2fa704c404d9ac151101d56e5817 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 18 Dec 2024 08:39:16 +0000 Subject: [PATCH] compiler: Generate fminf/fmaxf where necessary --- devito/symbolics/printer.py | 13 +++++++++---- tests/test_symbolics.py | 25 +++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/devito/symbolics/printer.py b/devito/symbolics/printer.py index 8c31d4ecf7..12eb2baca4 100644 --- a/devito/symbolics/printer.py +++ b/devito/symbolics/printer.py @@ -101,20 +101,25 @@ def _print_Rational(self, expr): def _print_math_func(self, expr, nest=False, known=None): cls = type(expr) name = cls.__name__ - if name not in self._prec_funcs: - return super()._print_math_func(expr, nest=nest, known=known) try: cname = self.known_functions[name] except KeyError: return super()._print_math_func(expr, nest=nest, known=known) + if cname not in self._prec_funcs: + return super()._print_math_func(expr, nest=nest, known=known) + if self.single_prec(expr): cname = '%sf' % cname - args = ', '.join((self._print(arg) for arg in expr.args)) + if nest and len(expr.args) > 2: + args = ', '.join([self._print(expr.args[0]), + self._print_math_func(cls(*expr.args[1:]))]) + else: + args = ', '.join([self._print(arg) for arg in expr.args]) - return '%s(%s)' % (cname, args) + return f'{cname}({args})' def _print_Pow(self, expr): # Need to override because of issue #1627 diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 496e325387..b8c64ed410 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -562,6 +562,31 @@ def test_minmax(): assert np.all(f.data == 4) +@pytest.mark.parametrize('dtype,expected', [ + (np.float32, ("fmaxf", "fminf")), + (np.float64, ("fmax", "fmin")), +]) +def test_minmax_precision(dtype, expected): + grid = Grid(shape=(5, 5), dtype=dtype) + + f = Function(name="f", grid=grid) + g = Function(name="g", grid=grid) + + eqn = Eq(f, Min(g, 4.0) + Max(g, 2.0)) + + op = Operator(eqn) + + g.data[:] = 3.0 + + op.apply() + + # Check generated code -- ensure it's using the fp64 versions of min/max, + # that is fminf/fmaxf + assert all(i in str(op) for i in expected) + + assert np.all(f.data == 6.0) + + class TestRelationsWithAssumptions: def test_multibounds_op(self):