Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Generate fminf/fmaxf where necessary #2501

Merged
merged 1 commit into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions devito/symbolics/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,20 +101,25 @@ def _print_Rational(self, expr):
def _print_math_func(self, expr, nest=False, known=None):
cls = type(expr)
name = cls.__name__
if name not in self._prec_funcs:
return super()._print_math_func(expr, nest=nest, known=known)

try:
cname = self.known_functions[name]
except KeyError:
return super()._print_math_func(expr, nest=nest, known=known)

if cname not in self._prec_funcs:
return super()._print_math_func(expr, nest=nest, known=known)

if self.single_prec(expr):
cname = '%sf' % cname

args = ', '.join((self._print(arg) for arg in expr.args))
if nest and len(expr.args) > 2:
args = ', '.join([self._print(expr.args[0]),
self._print_math_func(cls(*expr.args[1:]))])
else:
args = ', '.join([self._print(arg) for arg in expr.args])

return '%s(%s)' % (cname, args)
return f'{cname}({args})'

def _print_Pow(self, expr):
# Need to override because of issue #1627
Expand Down
25 changes: 25 additions & 0 deletions tests/test_symbolics.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,31 @@ def test_minmax():
assert np.all(f.data == 4)


@pytest.mark.parametrize('dtype,expected', [
(np.float32, ("fmaxf", "fminf")),
(np.float64, ("fmax", "fmin")),
])
def test_minmax_precision(dtype, expected):
grid = Grid(shape=(5, 5), dtype=dtype)

f = Function(name="f", grid=grid)
g = Function(name="g", grid=grid)

eqn = Eq(f, Min(g, 4.0) + Max(g, 2.0))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would use numpy numbers to check dtype is correct

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wdym exactly? I can think of at least two ways...

I landed on the above because that's the natural way people write the equations and it doesn't matter whether it's 4, 4.0, or 4.0F, because in the end, g will drive the fmin/fmax generation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mloubout I'm gonna merge this because I don't like broken CI, but as soon as you elaborate on how you want this test improved, I'll mkae the change and push it together with one of the upcoming branches (at least one more to go in before christmas)


op = Operator(eqn)

g.data[:] = 3.0

op.apply()

# Check generated code -- ensure it's using the fp64 versions of min/max,
# that is fminf/fmaxf
assert all(i in str(op) for i in expected)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be good (belt and braces) to also check:
assert all(i not in str(op) for i in not_expected)
to ensure that the correct versions are used throughout?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

overkill I'd say...


assert np.all(f.data == 6.0)


class TestRelationsWithAssumptions:

def test_multibounds_op(self):
Expand Down
Loading