Skip to content

Commit

Permalink
Wrote individual FillTriL, DiagTransform. Wrote composition FillScaleTriL
Browse files Browse the repository at this point in the history
  • Loading branch information
LuisA92 committed May 7, 2024
1 parent e4ee292 commit 70e48f3
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 72 deletions.
4 changes: 3 additions & 1 deletion src/rs_distributions/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from .fill_scale_tril import FillScaleTriL
from .fill_scale_tril import FillScaleTriL, FillTriL, DiagTransform

__all__ = [
"FillScaleTriL",
"FillTriL",
"DiagTransform",
]
112 changes: 47 additions & 65 deletions src/rs_distributions/transforms/fill_scale_tril.py
Original file line number Diff line number Diff line change
@@ -1,92 +1,74 @@
import torch
from torch.distributions import Transform, constraints
from torch.distributions import Transform, ComposeTransform, constraints
from torch.distributions.transforms import SoftplusTransform
from torch.distributions.utils import vec_to_tril_matrix, tril_matrix_to_vec


class FillTriL(Transform):
    """Bijective reshape between a flat vector and a lower triangular matrix.

    Maps a vector of length n*(n+1)/2 to an n x n lower triangular matrix
    (and back). Entries are copied unchanged, so the Jacobian is the
    identity and its log-absolute-determinant is zero.
    """

    def __init__(self):
        super().__init__()

    @property
    def domain(self):
        return constraints.real_vector

    @property
    def codomain(self):
        return constraints.lower_triangular

    @property
    def bijective(self):
        return True

    def _call(self, x):
        """
        Reshape the trailing vector dimension into a lower triangular matrix.

        Args:
            x (torch.Tensor): tensor with shape (..., n*(n+1)/2)

        Returns:
            torch.Tensor: tensor with shape (..., n, n), zero above the diagonal
        """
        return vec_to_tril_matrix(x)

    def _inverse(self, y):
        """
        Flatten the lower triangle of a matrix back into a vector.

        Args:
            y (torch.Tensor): lower triangular tensor with shape (..., n, n)

        Returns:
            torch.Tensor: tensor with shape (..., n*(n+1)/2)
        """
        return tril_matrix_to_vec(y)

    def log_abs_det_jacobian(self, x, y):
        # The transform only rearranges entries, so log|det J| is 0.
        # Return zeros with the batch shape (everything but the trailing
        # event dimension, since `domain` is real_vector). The previous
        # `x.shape[0]` produced a length-d vector for unbatched input and
        # the wrong shape for multi-dimensional batches.
        return torch.zeros(x.shape[:-1], dtype=x.dtype, device=x.device)


class DiagTransform(Transform):
    """Apply an elementwise transform to the diagonal of a matrix.

    Off-diagonal entries pass through unchanged; only the diagonal is
    mapped through ``diag_transform`` (e.g. SoftplusTransform to make it
    positive).
    """

    def __init__(self, diag_transform):
        """
        Args:
            diag_transform (torch.distributions.Transform): elementwise
                transform applied to the diagonal entries.
        """
        super().__init__()
        self.diag_transform = diag_transform

    @property
    def domain(self):
        return self.diag_transform.domain

    @property
    def codomain(self):
        return self.diag_transform.codomain

    @property
    def bijective(self):
        return self.diag_transform.bijective

    def _call(self, x):
        """
        Transform the diagonal of ``x``, leaving all other entries untouched.

        Args:
            x (torch.Tensor): tensor with shape (..., n, n)

        Returns:
            torch.Tensor: same shape as ``x``, with the diagonal transformed
        """
        diagonal = x.diagonal(dim1=-2, dim2=-1)
        return x.diagonal_scatter(self.diag_transform(diagonal), dim1=-2, dim2=-1)

    def _inverse(self, y):
        """
        Apply the inverse elementwise transform to the diagonal of ``y``.

        Args:
            y (torch.Tensor): tensor with shape (..., n, n)

        Returns:
            torch.Tensor: same shape as ``y``, with the diagonal inverted
        """
        diagonal = y.diagonal(dim1=-2, dim2=-1)
        return y.diagonal_scatter(self.diag_transform.inv(diagonal), dim1=-2, dim2=-1)

    def log_abs_det_jacobian(self, x, y):
        # Only the diagonal entries contribute to the Jacobian. The inner
        # transform must be given matching input/output pairs, so pass the
        # diagonal of y — the previous code passed the full matrix y, which
        # mismatched the diagonal input's shape.
        return self.diag_transform.log_abs_det_jacobian(
            x.diagonal(dim1=-2, dim2=-1), y.diagonal(dim1=-2, dim2=-1)
        )


class FillScaleTriL(ComposeTransform):
    """Transform a real vector into a lower triangular matrix with a
    positive diagonal (e.g. a Cholesky factor).

    Composition of FillTriL (vector -> lower triangular matrix) and
    DiagTransform (map the diagonal to positive values).
    """

    def __init__(self, diag_transform=None):
        """
        Args:
            diag_transform (torch.distributions.Transform, optional):
                transform used on the diagonal to ensure positive values.
                Defaults to SoftplusTransform.
        """
        # Use a None sentinel instead of `diag_transform=SoftplusTransform()`:
        # a default-argument instance is evaluated once and shared by every
        # FillScaleTriL ever constructed, and torch Transforms carry
        # per-instance caches that should not be shared implicitly.
        if diag_transform is None:
            diag_transform = SoftplusTransform()
        super().__init__([FillTriL(), DiagTransform(diag_transform=diag_transform)])

    @property
    def bijective(self):
        return True
12 changes: 6 additions & 6 deletions tests/transforms/fill_scale_tril.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pytest
from rs_distributions.transforms.fill_scale_tril import FillScaleTriL
from rs_distributions.transforms.fill_scale_tril import (
FillScaleTriL,
)
import torch
from torch.distributions.constraints import lower_cholesky

Expand All @@ -9,7 +11,7 @@ def test_forward_transform(batch_shape, d):
transform = FillScaleTriL()
input_shape = batch_shape + (d,)
input_vector = torch.randn(input_shape)
transformed_vector = transform._call(input_vector)
transformed_vector = transform(input_vector)

n = int((-1 + torch.sqrt(torch.tensor(1 + 8 * d))) / 2)
expected_output_shape = batch_shape + (n, n)
Expand All @@ -27,10 +29,8 @@ def test_forward_equals_inverse(batch_shape, d):
transform = FillScaleTriL()
input_shape = batch_shape + (d,)
input_vector = torch.randn(input_shape)
L = transform._call(input_vector)
invL = transform._inverse(L)

n = int((-1 + torch.sqrt(torch.tensor(1 + 8 * d))) / 2)
L = transform(input_vector)
invL = transform.inv(L)

assert torch.allclose(
input_vector, invL, atol=1e-4
Expand Down

0 comments on commit 70e48f3

Please sign in to comment.