Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrites for triangular matrices #1131

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,3 +1013,47 @@ def slogdet_specialization(fgraph, node):
k: slogdet_specialization_map[v] for k, v in dummy_replacements.items()
}
return replacements


def _find_triangular_from_cholesky(potential_triangular):
if not (
potential_triangular.owner is not None
and isinstance(potential_triangular.owner.op, Blockwise)
and isinstance(potential_triangular.owner.op.core_op, Cholesky)
):
return None

return potential_triangular


@register_canonicalize
@register_stabilize
@node_rewriter([det])
def det_triangular_to_prod_diag(fgraph, node):
inputs = node.inputs[0]
triangular_check = _find_triangular_from_cholesky(inputs)

if triangular_check:
det_val = inputs.diagonal().prod()
return [det_val]

return None


@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_inv_triangular_to_solve_triangular(fgraph, node):
core_op = node.op.core_op
if not (isinstance(core_op, ALL_INVERSE_OPS)):
return None

inputs = node.inputs[0]
triangular_check = _find_triangular_from_cholesky(inputs)

if triangular_check:
valid_eye = pt.eye(inputs.shape[-1])
inv_val = solve_triangular(inputs, valid_eye, lower=True)
return [inv_val]

return None
61 changes: 61 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,3 +996,64 @@ def test_slogdet_specialization():
f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN")
nodes = f.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, SLogDet) for node in nodes)


def test_det_triangular():
x = pt.matrix("x")
x_triangular = pt.linalg.cholesky(x)
z = pt.linalg.det(x_triangular)

# Rewrite Test
f_rewritten = function([x], z, mode="FAST_RUN")

nodes = f_rewritten.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, Det) for node in nodes)

# Numeric Test
x_test = np.random.rand(10, 10).astype(config.floatX)
x_psd = np.dot(x_test, x_test.T)
x_triangular = np.linalg.cholesky(x_psd)
det_val = np.linalg.det(x_triangular)
rewritten_val = f_rewritten(x_psd)
assert_allclose(
det_val,
rewritten_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)

# Case where rewrite should not be applied
y = pt.matrix("y")
z = pt.linalg.det(y)

f_rewritten = function([y], z, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
assert any(isinstance(node.op, Det) for node in nodes)


@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
def test_inv_triangular(inv_op):
x = pt.matrix("x")
x_triangular = pt.linalg.cholesky(x)
z = get_pt_function(x_triangular, inv_op)

# Rewrite Test
f_rewritten = function([x], z, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes

valid_inverses = (MatrixInverse, MatrixPinv)
assert not any(isinstance(node.op, valid_inverses) for node in nodes)

# Numeric Test
x_test = np.random.rand(10, 10).astype(config.floatX)
x_psd = np.dot(x_test, x_test.T)
x_triangular = np.linalg.cholesky(x_psd)
inv_val = np.linalg.inv(x_triangular)
rewritten_val = f_rewritten(x_psd)

assert_allclose(
inv_val,
rewritten_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)
Loading