diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index cd0daffb49..f5381b3c72 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -9,8 +9,9 @@ import dataclasses import functools import inspect +import math from builtins import bool, float, int, tuple -from typing import Any, Callable, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast +from typing import Any, Callable, Final, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast import numpy as np from numpy import float32, float64, int32, int64 @@ -196,40 +197,55 @@ def astype( return core_defs.dtype(type_).scalar_type(value) -UNARY_MATH_NUMBER_BUILTIN_NAMES = ["abs"] - -UNARY_MATH_FP_BUILTIN_NAMES = [ - "sin", - "cos", - "tan", - "arcsin", - "arccos", - "arctan", - "sinh", - "cosh", - "tanh", - "arcsinh", - "arccosh", - "arctanh", - "sqrt", - "exp", - "log", - "gamma", - "cbrt", - "floor", - "ceil", - "trunc", -] - -UNARY_MATH_FP_PREDICATE_BUILTIN_NAMES = ["isfinite", "isinf", "isnan"] +_UNARY_MATH_NUMBER_BUILTIN_IMPL: Final = {"abs": abs} +UNARY_MATH_NUMBER_BUILTIN_NAMES: Final = [*_UNARY_MATH_NUMBER_BUILTIN_IMPL.keys()] + +_UNARY_MATH_FP_BUILTIN_IMPL: Final = { + "sin": math.sin, + "cos": math.cos, + "tan": math.tan, + "arcsin": math.asin, + "arccos": math.acos, + "arctan": math.atan, + "sinh": math.sinh, + "cosh": math.cosh, + "tanh": math.tanh, + "arcsinh": math.asinh, + "arccosh": math.acosh, + "arctanh": math.atanh, + "sqrt": math.sqrt, + "exp": math.exp, + "log": math.log, + "gamma": math.gamma, + "cbrt": math.cbrt if hasattr(math, "cbrt") else np.cbrt, # match.cbrt() only added in 3.11 + "floor": math.floor, + "ceil": math.ceil, + "trunc": math.trunc, +} +UNARY_MATH_FP_BUILTIN_NAMES: Final = [*_UNARY_MATH_FP_BUILTIN_IMPL.keys()] + +_UNARY_MATH_FP_PREDICATE_BUILTIN_IMPL: Final = { + "isfinite": math.isfinite, + "isinf": math.isinf, + "isnan": math.isnan, +} +UNARY_MATH_FP_PREDICATE_BUILTIN_NAMES: Final = [*_UNARY_MATH_FP_PREDICATE_BUILTIN_IMPL.keys()] def _make_unary_math_builtin(name: str) -> None: + _math_builtin = ( + _UNARY_MATH_NUMBER_BUILTIN_IMPL + | _UNARY_MATH_FP_BUILTIN_IMPL + | _UNARY_MATH_FP_PREDICATE_BUILTIN_IMPL + )[name] + def impl(value: common.Field | core_defs.ScalarT, /) -> common.Field | core_defs.ScalarT: - # TODO(havogt): enable once we have a failing test (see `test_math_builtin_execution.py`) - # assert core_defs.is_scalar_type(value) # default implementation for scalars, Fields are handled via dispatch # noqa: ERA001 [commented-out-code] - # return getattr(math, name)(value)# noqa: ERA001 [commented-out-code] - raise NotImplementedError() + # TODO(havogt): enable tests in `test_math_builtin_execution.py` + assert core_defs.is_scalar_type( + value + ) # default implementation for scalars, Fields are handled via dispatch + + return _math_builtin(value) impl.__name__ = name globals()[name] = BuiltInFunction(impl)