diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py
index def4746a18..fce8a29964 100644
--- a/pytensor/link/numba/dispatch/elemwise.py
+++ b/pytensor/link/numba/dispatch/elemwise.py
@@ -1,7 +1,6 @@
 from collections.abc import Callable
 from functools import singledispatch
-from numbers import Number
-from textwrap import indent
+from textwrap import dedent, indent
 from typing import Any
 
 import numba
@@ -15,7 +14,6 @@
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
     create_numba_signature,
-    create_tuple_creator,
     numba_funcify,
     numba_njit,
     use_optimized_cheap_pass,
@@ -26,7 +24,7 @@
     encode_literals,
     store_core_outputs,
 )
-from pytensor.link.utils import compile_function_src, get_name_for_object
+from pytensor.link.utils import compile_function_src
 from pytensor.scalar.basic import (
     AND,
     OR,
@@ -163,40 +161,32 @@ def create_vectorize_func(
     return elemwise_fn
 
 
-def create_axis_reducer(
-    scalar_op: Op,
-    identity: np.ndarray | Number,
-    axis: int,
-    ndim: int,
-    dtype: numba.types.Type,
+def create_multiaxis_reducer(
+    scalar_op,
+    identity,
+    axes,
+    ndim,
+    dtype,
     keepdims: bool = False,
-    return_scalar=False,
-) -> numba.core.dispatcher.Dispatcher:
-    r"""Create Python function that performs a NumPy-like reduction on a given axis.
+):
+    r"""Construct a function that reduces multiple axes.
 
     The functions generated by this function take the following form:
 
     .. code-block:: python
 
-        def careduce_axis(x):
-            res_shape = tuple(
-                shape[i] if i < axis else shape[i + 1] for i in range(ndim - 1)
-            )
-            res = np.full(res_shape, identity, dtype=dtype)
-
-            x_axis_first = x.transpose(reaxis_first)
-
-            for m in range(x.shape[axis]):
-                reduce_fn(res, x_axis_first[m], res)
-
-            if keepdims:
-                return np.expand_dims(res, axis)
-            else:
-                return res
+        def careduce_add(x):
+            # For x.ndim == 3 and axes == (0, 1) and scalar_op == "Add"
+            x_shape = x.shape
+            res_shape = x_shape[2]
+            res = np.full(res_shape, numba_basic.to_scalar(0.0), dtype=out_dtype)
+            for i0 in range(x_shape[0]):
+                for i1 in range(x_shape[1]):
+                    for i2 in range(x_shape[2]):
+                        res[i2] += x[i0, i1, i2]
 
-    This can be removed/replaced when
-    https://github.com/numba/numba/issues/4504 is implemented.
+            return res
 
     Parameters
     ==========
@@ -204,25 +194,29 @@ def careduce_axis(x):
         The scalar :class:`Op` that performs the desired reduction.
     identity:
         The identity value for the reduction.
-    axis:
-        The axis to reduce.
+    axes:
+        The axes to reduce.
     ndim:
-        The number of dimensions of the result.
+        The number of dimensions of the input variable.
     dtype:
         The data type of the result.
-    keepdims:
-        Determines whether or not the reduced dimension is retained.
-
-
+    keepdims: boolean, default False
+        Whether to keep the reduced dimensions.
 
     Returns
     =======
     A Python function that can be JITed.
""" + # if len(axes) == 1: + # return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype) - axis = normalize_axis_index(axis, ndim) + axes = normalize_axis_tuple(axes, ndim) + if keepdims and len(axes) > 1: + raise NotImplementedError( + "Cannot keep multiple dimensions when reducing multiple axes" + ) - reduce_elemwise_fn_name = "careduce_axis" + careduce_fn_name = f"careduce_{scalar_op}" identity = str(identity) if identity == "inf": @@ -235,163 +229,56 @@ def careduce_axis(x): "numba_basic": numba_basic, "out_dtype": dtype, } + complete_reduction = len(axes) == ndim + kept_axis = tuple(i for i in range(ndim) if i not in axes) + + res_indices = [] + arr_indices = [] + for i in range(ndim): + index_label = f"i{i}" + arr_indices.append(index_label) + if i not in axes: + res_indices.append(index_label) + res_indices = ", ".join(res_indices) if res_indices else () + arr_indices = ", ".join(arr_indices) if arr_indices else () + + inplace_update_stmt = scalar_in_place_fn( + scalar_op, res_indices, "res", f"x[{arr_indices}]" + ) - if ndim > 1: - res_shape_tuple_ctor = create_tuple_creator( - lambda i, shape: shape[i] if i < axis else shape[i + 1], ndim - 1 - ) - global_env["res_shape_tuple_ctor"] = res_shape_tuple_ctor - - res_indices = [] - arr_indices = [] - count = 0 - - for i in range(ndim): - if i == axis: - arr_indices.append("i") - else: - res_indices.append(f"idx_arr[{count}]") - arr_indices.append(f"idx_arr[{count}]") - count = count + 1 - - res_indices = ", ".join(res_indices) - arr_indices = ", ".join(arr_indices) - - inplace_update_statement = scalar_in_place_fn( - scalar_op, res_indices, "res", f"x[{arr_indices}]" - ) - inplace_update_statement = indent(inplace_update_statement, " " * 4 * 3) - - return_expr = f"np.expand_dims(res, {axis})" if keepdims else "res" - reduce_elemwise_def_src = f""" -def {reduce_elemwise_fn_name}(x): - - x_shape = np.shape(x) - res_shape = res_shape_tuple_ctor(x_shape) - res = np.full(res_shape, numba_basic.to_scalar({identity}), dtype=out_dtype) - - axis_shape = x.shape[{axis}] - - for idx_arr in np.ndindex(res_shape): - for i in range(axis_shape): -{inplace_update_statement} - - return {return_expr} - """ + res_shape = f"({', '.join(f'x_shape[{i}]' for i in kept_axis)})" + if complete_reduction and ndim > 0: + # We accumulate on a scalar, not an array + res_creator = f"np.asarray({identity}).astype(out_dtype).item()" + inplace_update_stmt = inplace_update_stmt.replace("res[()]", "res") + return_obj = "np.asarray(res)" else: - inplace_update_statement = scalar_in_place_fn(scalar_op, "0", "res", "x[i]") - inplace_update_statement = indent(inplace_update_statement, " " * 4 * 2) - - return_expr = "res" if keepdims else "res.item()" - if not return_scalar: - return_expr = f"np.asarray({return_expr})" - reduce_elemwise_def_src = f""" -def {reduce_elemwise_fn_name}(x): - - res = np.full(1, numba_basic.to_scalar({identity}), dtype=out_dtype) - - axis_shape = x.shape[{axis}] - - for i in range(axis_shape): -{inplace_update_statement} - - return {return_expr} + res_creator = ( + f"np.full({res_shape}, np.asarray({identity}).item(), dtype=out_dtype)" + ) + return_obj = "res" + + if keepdims: + [axis] = axes + return_obj = f"np.expand_dims({return_obj}, {axis})" + + careduce_def_src = dedent( + f""" + def {careduce_fn_name}(x): + x_shape = x.shape + res_shape = {res_shape} + res = {res_creator} """ - - reduce_elemwise_fn_py = compile_function_src( - reduce_elemwise_def_src, reduce_elemwise_fn_name, {**globals(), **global_env} ) - - return 
-    return reduce_elemwise_fn_py
-
-
-def create_multiaxis_reducer(
-    scalar_op,
-    identity,
-    axes,
-    ndim,
-    dtype,
-    input_name="input",
-    return_scalar=False,
-):
-    r"""Construct a function that reduces multiple axes.
-
-    The functions generated by this function take the following form:
-
-    .. code-block:: python
-
-        def careduce_maximum(input):
-            axis_0_res = careduce_axes_fn_0(input)
-            axis_1_res = careduce_axes_fn_1(axis_0_res)
-            ...
-            axis_N_res = careduce_axes_fn_N(axis_N_minus_1_res)
-            return axis_N_res
-
-    The range 0-N is determined by the `axes` argument (i.e. the
-    axes to be reduced).
-
-
-    Parameters
-    ==========
-    scalar_op:
-        The scalar :class:`Op` that performs the desired reduction.
-    identity:
-        The identity value for the reduction.
-    axes:
-        The axes to reduce.
-    ndim:
-        The number of dimensions of the result.
-    dtype:
-        The data type of the result.
-    return_scalar:
-        If True, return a scalar, otherwise an array.
-
-    Returns
-    =======
-    A Python function that can be JITed.
-
-    """
-    if len(axes) == 1:
-        return create_axis_reducer(scalar_op, identity, axes[0], ndim, dtype)
-
-    axes = normalize_axis_tuple(axes, ndim)
-
-    careduce_fn_name = f"careduce_{scalar_op}"
-    global_env = {}
-    to_reduce = sorted(axes, reverse=True)
-    careduce_lines_src = []
-    var_name = input_name
-
-    for i, axis in enumerate(to_reduce):
-        careducer_axes_fn_name = f"careduce_axes_fn_{i}"
-        reducer_py_fn = create_axis_reducer(scalar_op, identity, axis, ndim, dtype)
-        reducer_fn = numba_basic.numba_njit(
-            boundscheck=False, fastmath=config.numba__fastmath
-        )(reducer_py_fn)
-
-        global_env[careducer_axes_fn_name] = reducer_fn
-
-        ndim -= 1
-        last_var_name = var_name
-        var_name = f"axis_{i}_res"
-        careduce_lines_src.append(
-            f"{var_name} = {careducer_axes_fn_name}({last_var_name})"
+    for axis in range(ndim):
+        careduce_def_src += indent(
+            f"for i{axis} in range(x_shape[{axis}]):\n",
+            " " * (4 + 4 * axis),
         )
-
-    careduce_assign_lines = indent("\n".join(careduce_lines_src), " " * 4)
-    if not return_scalar:
-        pre_result = "np.asarray"
-        post_result = ""
-    else:
-        pre_result = "np.asarray"
-        post_result = ".item()"
-
-    careduce_def_src = f"""
-def {careduce_fn_name}({input_name}):
-{careduce_assign_lines}
-    return {pre_result}({var_name}){post_result}
-    """
-
+    careduce_def_src += indent(inplace_update_stmt, " " * (4 + 4 * ndim))
+    careduce_def_src += "\n\n"
+    careduce_def_src += indent(f"return {return_obj}", " " * 4)
     careduce_fn = compile_function_src(
         careduce_def_src, careduce_fn_name, {**globals(), **global_env}
     )
@@ -545,32 +432,29 @@ def ov_elemwise(*inputs):
 
 @numba_funcify.register(Sum)
 def numba_funcify_Sum(op, node, **kwargs):
+    ndim_input = node.inputs[0].ndim
     axes = op.axis
     if axes is None:
         axes = list(range(node.inputs[0].ndim))
-
-    axes = tuple(axes)
-
-    ndim_input = node.inputs[0].ndim
+    else:
+        axes = normalize_axis_tuple(axes, ndim_input)
 
     if hasattr(op, "acc_dtype") and op.acc_dtype is not None:
         acc_dtype = op.acc_dtype
     else:
         acc_dtype = node.outputs[0].type.dtype
-
     np_acc_dtype = np.dtype(acc_dtype)
-
     out_dtype = np.dtype(node.outputs[0].dtype)
 
     if ndim_input == len(axes):
-
-        @numba_njit(fastmath=True)
+        # Slightly faster than `numba_funcify_CAReduce` for this case
+        @numba_njit(fastmath=config.numba__fastmath)
        def impl_sum(array):
            return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)

    elif len(axes) == 0:
-
-        @numba_njit(fastmath=True)
+        # These cases should be removed by rewrites!
+        @numba_njit(fastmath=config.numba__fastmath)
         def impl_sum(array):
             return np.asarray(array, dtype=out_dtype)
 
@@ -603,7 +487,6 @@ def numba_funcify_CAReduce(op, node, **kwargs):
     # Make sure it has the correct dtype
     scalar_op_identity = np.array(scalar_op_identity, dtype=np_acc_dtype)
 
-    input_name = get_name_for_object(node.inputs[0])
     ndim = node.inputs[0].ndim
     careduce_py_fn = create_multiaxis_reducer(
         op.scalar_op,
@@ -611,7 +494,6 @@ def numba_funcify_CAReduce(op, node, **kwargs):
         axes,
         ndim,
         np.dtype(node.outputs[0].type.dtype),
-        input_name=input_name,
     )
 
     careduce_fn = jit_compile_reducer(node, careduce_py_fn, reduce_to_scalar=False)
@@ -724,11 +606,11 @@ def numba_funcify_Softmax(op, node, **kwargs):
     if axis is not None:
         axis = normalize_axis_index(axis, x_at.ndim)
 
-        reduce_max_py = create_axis_reducer(
+        reduce_max_py = create_multiaxis_reducer(
             scalar_maximum, -np.inf, axis, x_at.ndim, x_dtype, keepdims=True
         )
-        reduce_sum_py = create_axis_reducer(
-            add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
+        reduce_sum_py = create_multiaxis_reducer(
+            add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
         )
 
     jit_fn = numba_basic.numba_njit(
@@ -761,8 +643,8 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
     axis = op.axis
     if axis is not None:
         axis = normalize_axis_index(axis, sm_at.ndim)
-        reduce_sum_py = create_axis_reducer(
-            add_as, 0.0, axis, sm_at.ndim, sm_dtype, keepdims=True
+        reduce_sum_py = create_multiaxis_reducer(
+            add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True
         )
 
     jit_fn = numba_basic.numba_njit(
@@ -793,16 +675,16 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
     if axis is not None:
         axis = normalize_axis_index(axis, x_at.ndim)
 
-        reduce_max_py = create_axis_reducer(
+        reduce_max_py = create_multiaxis_reducer(
             scalar_maximum,
             -np.inf,
-            axis,
+            (axis,),
             x_at.ndim,
             x_dtype,
             keepdims=True,
         )
-        reduce_sum_py = create_axis_reducer(
-            add_as, 0.0, axis, x_at.ndim, x_dtype, keepdims=True
+        reduce_sum_py = create_multiaxis_reducer(
+            add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
         )
 
     jit_fn = numba_basic.numba_njit(
diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py
index 3fb3979c27..ff2c23bcf3 100644
--- a/tests/link/numba/test_elemwise.py
+++ b/tests/link/numba/test_elemwise.py
@@ -15,7 +15,7 @@
 from pytensor.gradient import grad
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
-from pytensor.tensor.elemwise import DimShuffle
+from pytensor.tensor.elemwise import CAReduce, DimShuffle
 from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 from tests.link.numba.test_basic import (
@@ -23,7 +23,7 @@
     scalar_my_multi_out,
     set_test_value,
 )
-from tests.tensor.test_elemwise import TestElemwise
+from tests.tensor.test_elemwise import TestElemwise, careduce_benchmark_tester
 
 
 rng = np.random.default_rng(42849)
@@ -249,12 +249,12 @@ def test_Dimshuffle_non_contiguous():
         (
             lambda x, axis=None, dtype=None, acc_dtype=None: All(axis)(x),
             0,
-            set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
+            set_test_value(pt.vector(dtype="bool"), np.array([False, True, False])),
         ),
         (
             lambda x, axis=None, dtype=None, acc_dtype=None: Any(axis)(x),
             0,
-            set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
+            set_test_value(pt.vector(dtype="bool"), np.array([False, True, False])),
         ),
         (
             lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
@@ -301,6 +301,24 @@ def test_Dimshuffle_non_contiguous():
                 pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
             ),
         ),
+        (
+            lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
+                axis=axis, dtype=dtype, acc_dtype=acc_dtype
+            )(x),
+            (),  # Empty axes would normally be rewritten away, but we want to test it still works
+            set_test_value(
+                pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
+            ),
+        ),
+        (
+            lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
+                axis=axis, dtype=dtype, acc_dtype=acc_dtype
+            )(x),
+            None,
+            set_test_value(
+                pt.scalar(), np.array(99.0, dtype=config.floatX)
+            ),  # Scalar input would normally be rewritten away, but we want to test it still works
+        ),
         (
             lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
                 axis=axis, dtype=dtype, acc_dtype=acc_dtype
@@ -367,7 +385,7 @@ def test_CAReduce(careduce_fn, axis, v):
     g = careduce_fn(v, axis=axis)
     g_fg = FunctionGraph(outputs=[g])
 
-    compare_numba_and_py(
+    fn, _ = compare_numba_and_py(
         g_fg,
         [
             i.tag.test_value
@@ -375,6 +393,9 @@ def test_CAReduce(careduce_fn, axis, v):
             if not isinstance(i, SharedVariable | Constant)
         ],
     )
+    # Confirm CAReduce is in the compiled function
+    [node] = fn.maker.fgraph.apply_nodes
+    assert isinstance(node.op, CAReduce)
 
 
 def test_scalar_Elemwise_Clip():
@@ -619,10 +641,10 @@ def test_logsumexp_benchmark(size, axis, benchmark):
     X_lse_fn = pytensor.function([X], X_lse, mode="NUMBA")
 
     # JIT compile first
-    _ = X_lse_fn(X_val)
-    res = benchmark(X_lse_fn, X_val)
+    res = X_lse_fn(X_val)
     exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
     np.testing.assert_array_almost_equal(res, exp_res)
+    benchmark(X_lse_fn, X_val)
 
 
 def test_fused_elemwise_benchmark(benchmark):
@@ -653,3 +675,19 @@ def test_elemwise_out_type():
     x_val = np.broadcast_to(np.zeros((3,)), (6, 3))
 
     assert func(x_val).shape == (18,)
+
+
+@pytest.mark.parametrize(
+    "axis",
+    (0, 1, 2, (0, 1), (0, 2), (1, 2), None),
+    ids=lambda x: f"axis={x}",
+)
+@pytest.mark.parametrize(
+    "c_contiguous",
+    (True, False),
+    ids=lambda x: f"c_contiguous={x}",
+)
+def test_careduce_benchmark(axis, c_contiguous, benchmark):
+    return careduce_benchmark_tester(
+        axis, c_contiguous, mode="NUMBA", benchmark=benchmark
+    )
diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py
index 7ccc2fd95c..9f82c2675c 100644
--- a/tests/tensor/test_elemwise.py
+++ b/tests/tensor/test_elemwise.py
@@ -983,27 +983,33 @@ def test_CAReduce(self):
         assert vect_node.inputs[0] is bool_tns
 
 
-@pytest.mark.parametrize(
-    "axis",
-    (0, 1, 2, (0, 1), (0, 2), (1, 2), None),
-    ids=lambda x: f"axis={x}",
-)
-@pytest.mark.parametrize(
-    "c_contiguous",
-    (True, False),
-    ids=lambda x: f"c_contiguous={x}",
-)
-def test_careduce_benchmark(axis, c_contiguous, benchmark):
+def careduce_benchmark_tester(axis, c_contiguous, mode, benchmark):
     N = 256
     x_test = np.random.uniform(size=(N, N, N))
     transpose_axis = (0, 1, 2) if c_contiguous else (2, 0, 1)
 
     x = pytensor.shared(x_test, name="x", shape=x_test.shape)
     out = x.transpose(transpose_axis).sum(axis=axis)
-    fn = pytensor.function([], out)
+    fn = pytensor.function([], out, mode=mode)
     np.testing.assert_allclose(
         fn(),
         x_test.transpose(transpose_axis).sum(axis=axis),
     )
     benchmark(fn)
+
+
+@pytest.mark.parametrize(
+    "axis",
+    (0, 1, 2, (0, 1), (0, 2), (1, 2), None),
+    ids=lambda x: f"axis={x}",
+)
+@pytest.mark.parametrize(
+    "c_contiguous",
+    (True, False),
+    ids=lambda x: f"c_contiguous={x}",
+)
+def test_careduce_benchmark(axis, c_contiguous, benchmark):
+    return careduce_benchmark_tester(
+        axis, c_contiguous, mode="FAST_RUN", benchmark=benchmark
+    )
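
A minimal, hand-written sketch in plain NumPy of the nested-loop reduction that `create_multiaxis_reducer` generates, mirroring the `careduce_add` example from the docstring above for x.ndim == 3, axes == (0, 1) and a sum reduction. The function name, the random test values, and the final assertion are illustrative only and are not part of the diff; the real generated source is compiled via `compile_function_src` and then JIT-compiled by Numba.

import numpy as np

def careduce_add(x):
    # Keep axis 2; reduce over axes 0 and 1 by accumulating into `res` in place,
    # following the same loop-nest shape as the generated careduce source.
    x_shape = x.shape
    res_shape = x_shape[2]
    res = np.full(res_shape, 0.0, dtype=x.dtype)
    for i0 in range(x_shape[0]):
        for i1 in range(x_shape[1]):
            for i2 in range(x_shape[2]):
                res[i2] += x[i0, i1, i2]
    return res

x = np.random.default_rng(0).uniform(size=(4, 5, 6))
np.testing.assert_allclose(careduce_add(x), x.sum(axis=(0, 1)))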