diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index efcc2500a7..f11e33b41d 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -29,7 +29,7 @@
     stack,
     switch,
 )
-from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import (
     CAReduce,
     Elemwise,
@@ -2726,6 +2726,22 @@ def logsumexp(x, axis=None, keepdims=False):
     return log(sum(exp(x), axis=axis, keepdims=keepdims))
 
 
+# Predefine all batched variations of Dot
+_inner_prod = Blockwise(
+    _dot,
+    signature="(n),(n)->()",
+)
+
+_matrix_vec_prod = Blockwise(
+    _dot,
+    signature="(m,k),(k)->(m)",
+)
+
+_vec_matrix_prod = Blockwise(
+    _dot,
+    signature="(k),(k,n)->(n)",
+)
+
 _matrix_matrix_matmul = Blockwise(
     _dot,
     signature="(m,k),(k,n)->(m,n)",
@@ -2795,14 +2811,24 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
 
 
 @_vectorize_node.register(Dot)
-def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
+def vectorize_node_dot(op, node, batched_x, batched_y):
     old_x, old_y = node.inputs
-    if old_x.type.ndim == 2 and old_y.type.ndim == 2:
-        # If original input is equivalent to a matrix-matrix product,
-        # return specialized Matmul Op to avoid unnecessary new Ops.
-        return matmul(batched_x, batched_y).owner
-    else:
-        return vectorize_node_fallback(op, node, batched_x, batched_y)
+    old_x_ndim = old_x.type.ndim
+    old_y_ndim = old_y.type.ndim
+    match (old_x_ndim, old_y_ndim):
+        case (1, 1):
+            batch_op = _inner_prod
+        case (2, 1):
+            batch_op = _matrix_vec_prod
+        case (1, 2):
+            batch_op = _vec_matrix_prod
+        case (2, 2):
+            batch_op = _matrix_matrix_matmul
+        case _:
+            raise ValueError(
+                f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
+            )
+    return batch_op(batched_x, batched_y).owner
 
 
 def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index aa2d279f43..29b8ebb6cc 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -44,6 +44,10 @@
     Prod,
     Sum,
     _conj,
+    _inner_prod,
+    _matrix_matrix_matmul,
+    _matrix_vec_prod,
+    _vec_matrix_prod,
     add,
     digamma,
     dot,
@@ -242,6 +246,62 @@ def local_batched_matmul_to_core_matmul(fgraph, node):
     return None
+
+
+@register_canonicalize
+@register_specialize
+@node_rewriter([_inner_prod, _matrix_vec_prod, _vec_matrix_prod, _matrix_matrix_matmul])
+def local_blockwise_dot_to_mul(fgraph, node):
+    """Rewrite blockwise dots that correspond to multiplication without summation.
+
+    We don't touch the regular dot, to not interfere with the BLAS optimizations.
+    """
+    a, b = node.inputs
+    a_st_shape = a.type.shape
+    b_st_shape = b.type.shape
+    core_a_ndim = len(node.op.inputs_sig[0])
+    core_b_ndim = len(node.op.inputs_sig[1])
+
+    if core_a_ndim > 2 or core_b_ndim > 2:
+        # Shouldn't happen, but here just in case
+        return None
+
+    if core_b_ndim == 1:
+        if a_st_shape[-1] == 1 or b_st_shape[-1] == 1:
+            if core_a_ndim == 1:
+                # inner product: (..., 1) * (..., 1) -> (...)
+                # just squeeze the last dimensions of a and b
+                new_a = a.squeeze(-1)
+                new_b = b.squeeze(-1)
+            else:
+                # matrix vector product: (..., m, 1) * (..., 1) -> (..., m)
+                # the last dimension of b is already aligned for the elemwise multiplication
+                # after we squeeze the last dimension of a
+                new_a = a.squeeze(-1)
+                new_b = b
+        else:
+            return None
+
+    else:
+        if a_st_shape[-1] == 1 or b_st_shape[-2] == 1:
+            if core_a_ndim == 1:
+                # vector matrix product: (..., 1) * (..., 1, n) -> (..., n)
+                # the last dimension of a is already aligned for the elemwise multiplication
+                # after we squeeze the second-to-last dimension of b
+                new_a = a
+                new_b = b.squeeze(-2)
+            else:
+                # matrix matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n)
+                # the dimensions of a and b are already aligned for the elemwise multiplication
+                new_a = a
+                new_b = b
+        else:
+            return None
+
+    new_a = copy_stack_trace(a, new_a)
+    new_b = copy_stack_trace(b, new_b)
+    new_out = copy_stack_trace(node.out, mul(new_a, new_b))
+    return [new_out]
+
+
 def is_inverse_pair(node_op, prev_op, inv_pair):
     """
     Given two consecutive operations, check if they are the
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index debcf44c64..ab274a04f7 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -16,7 +16,8 @@
 from pytensor.compile.mode import Mode, get_default_mode, get_mode
 from pytensor.compile.ops import DeepCopyOp, deep_copy_op
 from pytensor.configdefaults import config
-from pytensor.graph.basic import Apply, equal_computations
+from pytensor.graph import vectorize_graph
+from pytensor.graph.basic import Apply, ancestors, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import (
     SequentialNodeRewriter,
@@ -4571,3 +4572,52 @@ def test_log_kv_stabilization():
         out.eval({x: 1000.0}, mode=mode),
         -1003.2180912984705,
     )
+
+
+@pytest.mark.parametrize(
+    "a_shape,b_shape",
+    [
+        ((1,), (1,)),
+        ((3, 1), (1,)),
+        ((1,), (1, 3)),
+        ((3, 1), (1, 3)),
+    ],
+)
+@pytest.mark.parametrize("batched", (False, True))
+def test_local_dot_to_mul(batched, a_shape, b_shape):
+    a = tensor("a", shape=a_shape)
+    b = tensor("b", shape=b_shape)
+
+    out = dot(a, b)
+    if batched:
+        batch_a = tensor("batch_a", shape=(1, 5, *a_shape))
+        batch_b = tensor("batch_b", shape=(7, 1, *b_shape))
+        out = vectorize_graph(out, {a: batch_a, b: batch_b})
+        a = batch_a
+        b = batch_b
+
+    assert (
+        sum(
+            isinstance(var.owner.op, (Blockwise | Dot))
+            for var in ancestors([out])
+            if var.owner
+        )
+        == 1
+    )
+
+    # For now, the rewrite only applies to batched Dots
+    rewritten_out = rewrite_graph(out)
+    assert rewritten_out.type.shape == out.type.shape
+    assert sum(
+        isinstance(var.owner.op, (Blockwise | Dot))
+        for var in ancestors([rewritten_out])
+        if var.owner
+    ) == (0 if batched else 1)
+
+    a_test = np.random.normal(size=a.type.shape).astype(a.type.dtype)
+    b_test = np.random.normal(size=b.type.shape).astype(b.type.dtype)
+    test_mode = Mode(linker="py", optimizer=None)
+    np.testing.assert_allclose(
+        out.eval({a: a_test, b: b_test}, mode=test_mode),
+        rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode),
+    )
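For reference, here is a minimal standalone sketch (not part of the patch) of what the new rewrite is expected to do once registered: a vectorized dot whose contracted dimension is statically 1 should be lowered to a broadcasted elementwise multiplication, so the rewritten graph contains no Blockwise or Dot nodes. The shapes and variable names below are illustrative assumptions.

import numpy as np
import pytensor.tensor as pt
from pytensor.graph import vectorize_graph
from pytensor.graph.rewriting.utils import rewrite_graph

# Core matrix-matrix dot with a contracted dimension of size 1 ...
a = pt.tensor("a", shape=(3, 1))
b = pt.tensor("b", shape=(1, 3))
# ... vectorized over a batch dimension of size 5
batch_a = pt.tensor("batch_a", shape=(5, 3, 1))
batch_b = pt.tensor("batch_b", shape=(5, 1, 3))
out = vectorize_graph(pt.dot(a, b), {a: batch_a, b: batch_b})

# With this rewrite registered, the Blockwise(Dot) should be replaced by a broadcasted Mul
rewritten = rewrite_graph(out)

# Both graphs compute the same values
x = np.random.normal(size=(5, 3, 1)).astype(batch_a.type.dtype)
y = np.random.normal(size=(5, 1, 3)).astype(batch_b.type.dtype)
np.testing.assert_allclose(
    out.eval({batch_a: x, batch_b: y}),
    rewritten.eval({batch_a: x, batch_b: y}),
)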