Skip to content

Commit

Permalink
Slogdet returns naive expression and is optimized later (#1041)
Browse files Browse the repository at this point in the history
  • Loading branch information
tanish1729 authored Nov 17, 2024
1 parent 33a4d48 commit bad8d20
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 54 deletions.
29 changes: 28 additions & 1 deletion pytensor/tensor/nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor import TensorLike
from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm
from pytensor.tensor.basic import as_tensor_variable, diagonal
Expand Down Expand Up @@ -266,7 +267,33 @@ def __str__(self):
return "SLogDet"


slogdet = Blockwise(SLogDet())
def slogdet(x: TensorLike) -> tuple[ptb.TensorVariable, ptb.TensorVariable]:
    """
    Build the sign and (natural) log-absolute-value of the determinant of `x`.

    The graph produced here is deliberately naive: both outputs are derived
    from a single ``det(x)`` expression and are meant to be optimized later by
    rewrites that target the det operation.

    Parameters
    ----------
    x : (..., M, M) tensor or tensor_like
        Input tensor, has to be square.

    Returns
    -------
    A tuple with the following attributes:

    sign : (...) tensor_like
        A number representing the sign of the determinant. For a real matrix,
        this is 1, 0, or -1.
    logabsdet : (...) tensor_like
        The natural log of the absolute value of the determinant.

    If the determinant is zero, then `sign` will be 0 and `logabsdet`
    will be -inf. In all cases, the determinant is equal to
    ``sign * exp(logabsdet)``.
    """
    determinant = det(x)
    sign = ptm.sign(determinant)
    logabsdet = ptm.log(ptm.abs(determinant))
    return sign, logabsdet


class Eig(Op):
Expand Down
120 changes: 72 additions & 48 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from collections.abc import Callable
from typing import cast

import numpy as np

from pytensor import Variable
from pytensor import tensor as pt
from pytensor.compile import optdb
Expand All @@ -11,7 +13,7 @@
in2out,
node_rewriter,
)
from pytensor.scalar.basic import Mul
from pytensor.scalar.basic import Abs, Log, Mul, Sign
from pytensor.tensor.basic import (
AllocDiag,
ExtractDiag,
Expand All @@ -30,11 +32,11 @@
KroneckerProduct,
MatrixInverse,
MatrixPinv,
SLogDet,
det,
inv,
kron,
pinv,
slogdet,
svd,
)
from pytensor.tensor.rewriting.basic import (
Expand Down Expand Up @@ -785,45 +787,6 @@ def rewrite_det_blockdiag(fgraph, node):
return [prod(det_sub_matrices)]


@register_canonicalize
@register_stabilize
@node_rewriter([slogdet])
def rewrite_slogdet_blockdiag(fgraph, node):
    """
    Simplify the slogdet of a block-diagonal matrix.

    The sign and log-determinant are computed per sub matrix and combined:

    slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....)

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    # The rewrite only applies when the input is produced by a (blockwise) BlockDiagonal
    producer = node.inputs[0].owner
    if not producer:
        return None
    if not isinstance(producer.op, Blockwise):
        return None
    if not isinstance(producer.op.core_op, BlockDiagonal):
        return None

    # Each input of the block_diag node is one of the composing sub matrices
    per_block = [slogdet(block) for block in producer.inputs]
    signs, logdets = zip(*per_block)

    return [prod(signs), sum(logdets)]


@register_canonicalize
@register_stabilize
@node_rewriter([ExtractDiag])
Expand Down Expand Up @@ -860,10 +823,10 @@ def rewrite_diag_kronecker(fgraph, node):

@register_canonicalize
@register_stabilize
@node_rewriter([slogdet])
def rewrite_slogdet_kronecker(fgraph, node):
@node_rewriter([det])
def rewrite_det_kronecker(fgraph, node):
"""
This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
This rewrite simplifies the determinant of a kronecker-structured matrix by extracting the individual sub matrices and returning the det values computed using those
Parameters
----------
Expand All @@ -884,13 +847,12 @@ def rewrite_slogdet_kronecker(fgraph, node):

# Find the matrices
a, b = potential_kron.inputs
signs, logdets = zip(*[slogdet(a), slogdet(b)])
dets = [det(a), det(b)]
sizes = [a.shape[-1], b.shape[-1]]
prod_sizes = prod(sizes, no_zeros_in_input=True)
signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)]
logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]
det_final = prod([dets[i] ** (prod_sizes / sizes[i]) for i in range(2)])

return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]
return [det_final]


@register_canonicalize
Expand Down Expand Up @@ -989,3 +951,65 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply):
"jax",
position=0.9, # Run before canonicalization
)


@register_specialize
@node_rewriter([det])
def slogdet_specialization(fgraph, node):
    """
    Rewrite sign(det), log(det) and log(abs(det)) in terms of the SLogDet operation.

    The rewrite is only applied when *every* use of the determinant can be
    served by SLogDet. Otherwise the Det node would have to stay in the graph
    next to the new SLogDet node and the determinant would be computed twice.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    dictionary of Variables, optional
        Dictionary of nodes and what they should be replaced with, or None if no optimization was performed
    """

    def is_elemwise_of(apply_node, scalar_op_type):
        # True when `apply_node` is an Elemwise wrapping the given scalar op type
        return isinstance(apply_node.op, Elemwise) and isinstance(
            apply_node.op.scalar_op, scalar_op_type
        )

    # Maps each variable to be replaced to the kind of SLogDet-derived value it needs
    dummy_replacements = {}
    for client, _ in fgraph.clients[node.outputs[0]]:
        # Check for sign(det)
        if is_elemwise_of(client, Sign):
            dummy_replacements[client.outputs[0]] = "sign"

        # Check for log(abs(det))
        elif is_elemwise_of(client, Abs):
            abs_clients = fgraph.clients[client.outputs[0]]
            # abs(det) must feed nothing but log: any other use keeps abs(det),
            # and therefore det, alive alongside SLogDet
            if not abs_clients or not all(
                is_elemwise_of(abs_client, Log) for abs_client, _ in abs_clients
            ):
                return None
            # Replace every log(abs(det)), not just one of them
            for log_node, _ in abs_clients:
                dummy_replacements[log_node.outputs[0]] = "log_abs_det"

        # Check for log(det)
        elif is_elemwise_of(client, Log):
            dummy_replacements[client.outputs[0]] = "log_det"

        # Det is used directly for something else, don't rewrite to avoid computing two dets
        else:
            return None

    if not dummy_replacements:
        return None

    [x] = node.inputs
    sign_det_x, log_abs_det_x = SLogDet()(x)
    # log(det) is nan for a negative determinant, whereas log(abs(det)) is not
    log_det_x = pt.where(pt.eq(sign_det_x, -1), np.nan, log_abs_det_x)
    slogdet_specialization_map = {
        "sign": sign_det_x,
        "log_abs_det": log_abs_det_x,
        "log_det": log_det_x,
    }
    return {
        var: slogdet_specialization_map[kind]
        for var, kind in dummy_replacements.items()
    }
6 changes: 4 additions & 2 deletions tests/link/pytorch/test_nlinalg.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections.abc import Sequence

import numpy as np
import pytest

Expand All @@ -22,13 +24,13 @@ def matrix_test():

@pytest.mark.parametrize(
"func",
(pt_nla.eig, pt_nla.eigh, pt_nla.slogdet, pt_nla.inv, pt_nla.det),
(pt_nla.eig, pt_nla.eigh, pt_nla.SLogDet(), pt_nla.inv, pt_nla.det),
)
def test_lin_alg_no_params(func, matrix_test):
x, test_value = matrix_test

out = func(x)
out_fg = FunctionGraph([x], out if isinstance(out, list) else [out])
out_fg = FunctionGraph([x], out if isinstance(out, Sequence) else [out])

def assert_fn(x, y):
np.testing.assert_allclose(x, y, rtol=1e-3)
Expand Down
96 changes: 93 additions & 3 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
KroneckerProduct,
MatrixInverse,
MatrixPinv,
SLogDet,
matrix_inverse,
svd,
)
Expand Down Expand Up @@ -719,7 +720,7 @@ def test_det_blockdiag_rewrite():


def test_slogdet_blockdiag_rewrite():
n_matrices = 100
n_matrices = 10
matrix_size = (5, 5)
sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)])
Expand Down Expand Up @@ -776,11 +777,34 @@ def test_diag_kronecker_rewrite():
)


def test_det_kronecker_rewrite():
    """det(kron(a, b)) compiles without a KroneckerProduct node and matches NumPy."""
    a, b = pt.dmatrices("a", "b")
    det_of_kron = pt.linalg.det(pt.linalg.kron(a, b))
    compiled = function([a, b], [det_of_kron], mode="FAST_RUN")

    # Rewrite Test
    compiled_nodes = compiled.maker.fgraph.apply_nodes
    assert not any(isinstance(n.op, KroneckerProduct) for n in compiled_nodes)

    # Value Test
    a_val, b_val = np.random.rand(2, 20, 20)
    expected_det = np.linalg.det(np.kron(a_val, b_val))
    actual_det = compiled(a_val, b_val)
    tol = 1e-3 if config.floatX == "float32" else 1e-8
    assert_allclose(expected_det, actual_det, atol=tol, rtol=tol)


def test_slogdet_kronecker_rewrite():
a, b = pt.dmatrices("a", "b")
kron_prod = pt.linalg.kron(a, b)
sign_output, logdet_output = pt.linalg.slogdet(kron_prod)
f_rewritten = function([kron_prod], [sign_output, logdet_output], mode="FAST_RUN")
f_rewritten = function([a, b], [sign_output, logdet_output], mode="FAST_RUN")

# Rewrite Test
nodes = f_rewritten.maker.fgraph.apply_nodes
Expand All @@ -790,7 +814,7 @@ def test_slogdet_kronecker_rewrite():
a_test, b_test = np.random.rand(2, 20, 20)
kron_prod_test = np.kron(a_test, b_test)
sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test)
rewritten_sign_val, rewritten_logdet_val = f_rewritten(kron_prod_test)
rewritten_sign_val, rewritten_logdet_val = f_rewritten(a_test, b_test)
assert_allclose(
sign_output_test,
rewritten_sign_val,
Expand Down Expand Up @@ -906,3 +930,69 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
assert any(isinstance(node.op, Cholesky) for node in nodes)


def test_slogdet_specialization():
    """Check when det-based expressions are (and are not) specialized to SLogDet."""
    x = pt.dmatrix("x")
    a = np.random.rand(20, 20)
    det_x = pt.linalg.det(x)
    det_a = np.linalg.det(a)

    sign_det_x = pt.sign(det_x)
    log_abs_det_x = pt.log(pt.abs(det_x))
    log_det_x = pt.log(det_x)
    exp_det_x = pt.exp(det_x)

    tol = 1e-3 if config.floatX == "float32" else 1e-8

    def compile_fast_run(outputs):
        # Compile with the full rewrite pipeline and expose the final graph nodes
        fn = function([x], outputs, mode="FAST_RUN")
        return fn, fn.maker.fgraph.apply_nodes

    def assert_rewritten(nodes):
        # Exactly one SLogDet must remain and no Det at all
        assert len([n for n in nodes if isinstance(n.op, SLogDet)]) == 1
        assert not any(isinstance(n.op, Det) for n in nodes)

    def assert_not_rewritten(nodes):
        assert not any(isinstance(n.op, SLogDet) for n in nodes)

    # REWRITE TESTS
    # sign(det(x)), log(abs(det(x))) and log(det(x)), each compared with NumPy
    single_output_cases = [
        (sign_det_x, np.sign(det_a)),
        (log_abs_det_x, np.log(np.abs(det_a))),
        (log_det_x, np.log(det_a)),
    ]
    for symbolic_out, expected in single_output_cases:
        fn, nodes = compile_fast_run([symbolic_out])
        assert_rewritten(nodes)
        assert_allclose(expected, fn(a), atol=tol, rtol=tol)

    # More than 1 valid function
    _, nodes = compile_fast_run([sign_det_x, log_abs_det_x])
    assert_rewritten(nodes)

    # Other functions (rewrite shouldnt be applied to these)
    # Only invalid functions
    _, nodes = compile_fast_run([exp_det_x])
    assert_not_rewritten(nodes)

    # Invalid + Valid function
    _, nodes = compile_fast_run([exp_det_x, sign_det_x])
    assert_not_rewritten(nodes)

0 comments on commit bad8d20

Please sign in to comment.