diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index cd202fe3ed..0d4c694b23 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -1013,3 +1013,47 @@ def slogdet_specialization(fgraph, node): k: slogdet_specialization_map[v] for k, v in dummy_replacements.items() } return replacements + + +def _find_triangular_from_cholesky(potential_triangular): + if not ( + potential_triangular.owner is not None + and isinstance(potential_triangular.owner.op, Blockwise) + and isinstance(potential_triangular.owner.op.core_op, Cholesky) + ): + return None + + return potential_triangular + + +@register_canonicalize +@register_stabilize +@node_rewriter([det]) +def det_triangular_to_prod_diag(fgraph, node): + inputs = node.inputs[0] + triangular_check = _find_triangular_from_cholesky(inputs) + + if triangular_check: + det_val = inputs.diagonal().prod() + return [det_val] + + return None + + +@register_canonicalize +@register_stabilize +@node_rewriter([Blockwise]) +def rewrite_inv_triangular_to_solve_triangular(fgraph, node): + core_op = node.op.core_op + if not (isinstance(core_op, ALL_INVERSE_OPS)): + return None + + inputs = node.inputs[0] + triangular_check = _find_triangular_from_cholesky(inputs) + + if triangular_check: + valid_eye = pt.eye(inputs.shape[-1]) + inv_val = solve_triangular(inputs, valid_eye, lower=True) + return [inv_val] + + return None diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index c9b9afff19..60fbdd1f9a 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -996,3 +996,64 @@ def test_slogdet_specialization(): f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN") nodes = f.maker.fgraph.apply_nodes assert not any(isinstance(node.op, SLogDet) for node in nodes) + + +def test_det_triangular(): + x = pt.matrix("x") + x_triangular = pt.linalg.cholesky(x) + z = pt.linalg.det(x_triangular) + + # Rewrite Test + f_rewritten = function([x], z, mode="FAST_RUN") + + nodes = f_rewritten.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, Det) for node in nodes) + + # Numeric Test + x_test = np.random.rand(10, 10).astype(config.floatX) + x_psd = np.dot(x_test, x_test.T) + x_triangular = np.linalg.cholesky(x_psd) + det_val = np.linalg.det(x_triangular) + rewritten_val = f_rewritten(x_psd) + assert_allclose( + det_val, + rewritten_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + # Case where rewrite should not be applied + y = pt.matrix("y") + z = pt.linalg.det(y) + + f_rewritten = function([y], z, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + assert any(isinstance(node.op, Det) for node in nodes) + + +@pytest.mark.parametrize("inv_op", ["inv", "pinv"]) +def test_inv_triangular(inv_op): + x = pt.matrix("x") + x_triangular = pt.linalg.cholesky(x) + z = get_pt_function(x_triangular, inv_op) + + # Rewrite Test + f_rewritten = function([x], z, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + + valid_inverses = (MatrixInverse, MatrixPinv) + assert not any(isinstance(node.op, valid_inverses) for node in nodes) + + # Numeric Test + x_test = np.random.rand(10, 10).astype(config.floatX) + x_psd = np.dot(x_test, x_test.T) + x_triangular = np.linalg.cholesky(x_psd) + inv_val = np.linalg.inv(x_triangular) + rewritten_val = f_rewritten(x_psd) + + assert_allclose( + inv_val, + rewritten_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + )