implement LU transform onnx export compatible
fariedabuzaid committed Sep 27, 2023
1 parent 91d1409 commit c873221
Showing 3 changed files with 24 additions and 20 deletions.
17 changes: 11 additions & 6 deletions src/veriflow/linalg.py
@@ -16,11 +16,16 @@ def solve_triangular(M: torch.Tensor, y: torch.Tensor, pivot: Optional[int]=None
Returns:
(torch.Tensor): Solution of the system $Mx=y$
"""

dim = M.size(0)
if dim == 1:
return y / M

if pivot is None:
# Determine orientation of Matrix
if all([M[i, j] == 0 for i in range(M.size[0]) for j in range(i, M.size[1])]):
if all([M[i, j] == 0. for i in range(dim) for j in range(i+1, dim)]):
pivot = 0
elif all([M[i, j] == 0 for i in range(M.size[0]) for j in range(0, i+1)]):
elif all([M[i, j] == 0. for i in range(dim) for j in range(0, i)]):
pivot = -1
else:
raise ValueError("M needs to be triangular.")
@@ -31,15 +36,15 @@ def solve_triangular(M: torch.Tensor, y: torch.Tensor, pivot: Optional[int]=None
x = torch.zeros_like(y)
x[pivot] = y[pivot] / M[pivot, pivot]

y_next = (y - x[pivot] * L[:, pivot])
y_next = (y - x[pivot] * M[:, pivot])
if pivot == 0:
y_next = y_next[1:]
M_next = M[1:, 1:]
x[1:] = solve_triangular(y_next, M_next)
x[1:] = solve_triangular(M_next, y_next, pivot=pivot)
else:
y_next = y_next[:-1]
LMnext = M[:-1, :-1]
x[:-1] = solve_triangular(y_next, M_next)
M_next = M[:-1, :-1]
x[:-1] = solve_triangular(M_next, y_next, pivot=pivot)

return x
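
The updated `solve_triangular` first detects whether `M` is lower or upper triangular by checking only the strictly upper (or strictly lower) entries, then solves recursively, eliminating one variable per call and forwarding the pivot. The loop-based sketch below shows the same forward-substitution idea for the lower-triangular case and checks it against PyTorch's built-in solver; it is an illustration, not the library's routine.

```python
import torch

def solve_lower_triangular(M: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Forward substitution for lower-triangular M (a sketch, not the library routine)."""
    dim = M.size(0)
    x = torch.zeros_like(y)
    for i in range(dim):
        # x[i] only depends on the entries already solved for.
        x[i] = (y[i] - M[i, :i] @ x[:i]) / M[i, i]
    return x

# Sanity check against PyTorch's built-in triangular solver.
M = torch.randn(4, 4).tril()
M.diagonal().add_(4.0)                       # keep the system well conditioned
y = torch.randn(4)
ref = torch.linalg.solve_triangular(M, y.unsqueeze(-1), upper=False).squeeze(-1)
print(torch.allclose(solve_lower_triangular(M, y), ref, atol=1e-5))   # expected: True
```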

24 changes: 11 additions & 13 deletions src/veriflow/transforms.py
Expand Up @@ -115,8 +115,7 @@ def is_feasible(self) -> bool:
return (self.scale != 0).all()

def add_jitter(self, jitter: float = 1e-6) -> None:
"""Adds jitter to the diagonal elements of $\mathbf{U}$. This is useful to ensure that the transformation
is invertible."""
"""Adds jitter to the diagonal elements of $\mathbf{U}$."""
perturbation = torch.randn(self.dim, device=self.U_raw.device) * jitter
self.U_raw = self.scale + perturbation

@@ -201,8 +200,6 @@ def with_cache(self, cache_size: int = 1):
return Permute(self.permutation, cache_size=cache_size)




class LUTransform(BaseTransform):
"""Implementation of a linear bijection transform. Applies a transform $y = (\mathbf{L}\mathbf{U})^{-1}x$, where $\mathbf{L}$ is a
lower triangular matrix with unit diagonal and $\mathbf{U}$ is an upper triangular matrix. Bijectivity is guaranteed by
@@ -233,7 +230,7 @@ def __init__(self, dim: int, *args, **kwargs):

self.input_shape = dim

self.L_mask = torch.tril(torch.ones(dim, dim), diagonal=1)
self.L_mask = torch.tril(torch.ones(dim, dim), diagonal=-1)
self.U_mask = torch.triu(torch.ones(dim, dim), diagonal=0)

self.L_raw.register_hook(lambda grad: grad * self.L_mask)
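
`LUTransform` keeps two full square parameter matrices and holds them to the LU structure with masks: `L_mask` is now strictly lower triangular (`diagonal=-1`) so the unit diagonal is added separately, `U_mask` keeps the upper triangle including the diagonal, and gradient hooks zero out updates that would leave the triangular structure. A minimal sketch of that masking pattern, with illustrative names and under the assumption that the effective weight is `L @ U`:

```python
import torch
import torch.nn as nn

class MaskedLU(nn.Module):
    """Sketch of an LU-style parameterization kept triangular via masks (illustrative only)."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.L_raw = nn.Parameter(torch.randn(dim, dim) * 0.01)
        self.U_raw = nn.Parameter(torch.eye(dim) + torch.randn(dim, dim) * 0.01)
        # Strictly lower-triangular and upper-triangular masks.
        self.register_buffer("L_mask", torch.tril(torch.ones(dim, dim), diagonal=-1))
        self.register_buffer("U_mask", torch.triu(torch.ones(dim, dim), diagonal=0))
        # Zero out gradient components outside each mask.
        self.L_raw.register_hook(lambda grad: grad * self.L_mask)
        self.U_raw.register_hook(lambda grad: grad * self.U_mask)

    @property
    def L(self) -> torch.Tensor:
        # Unit diagonal added explicitly; no in-place fill_diagonal_ involved.
        return self.L_raw * self.L_mask + torch.eye(self.dim, device=self.L_raw.device)

    @property
    def U(self) -> torch.Tensor:
        return self.U_raw * self.U_mask

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ (self.L @ self.U).T
```

Building `L` as `L_raw.tril(-1) + torch.eye(dim)` rather than calling `fill_diagonal_(1)` also avoids an in-place operation, which is often the kind of op that trips up ONNX tracing.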
@@ -246,7 +243,7 @@ def init_params(self):

init.kaiming_uniform_(self.L_raw, nonlinearity="relu")
with torch.no_grad():
self.L_raw.copy_(self.L_raw.tril(diagonal=1).fill_diagonal_(1))
self.L_raw.copy_(self.L_raw.tril(diagonal=-1).fill_diagonal_(1))

init.kaiming_uniform_(self.U_raw, nonlinearity="relu")
with torch.no_grad():
@@ -264,7 +261,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
Ly_0 &= x + LU\textrm{bias} \\
Uy &= y_0
\end{align*}
:param x: input tensor
:type x: torch.Tensor
:return: transformed tensor $(LU)x + \mathrm{bias}$
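
Whichever direction of this bijection multiplies by $\mathbf{L}\mathbf{U}$, the other direction never needs an explicit matrix inverse: it reduces to two triangular solves, first against $\mathbf{L}$ and then against $\mathbf{U}$, which is what makes the `solve_triangular` routine above relevant to this layer. A small illustration of that identity, not the repository's implementation:

```python
import torch

dim = 3
L = torch.eye(dim) + torch.randn(dim, dim).tril(-1)   # unit-diagonal lower factor
U = torch.randn(dim, dim).triu()
U.diagonal().add_(3.0)                                 # keep U comfortably invertible
bias = torch.randn(dim)
x = torch.randn(dim)

y = (L @ U) @ x + bias                                 # the linear map in one direction
# Invert by two triangular solves instead of forming (L @ U)^{-1}.
z = torch.linalg.solve_triangular(L, (y - bias).unsqueeze(-1), upper=False, unitriangular=True)
x_rec = torch.linalg.solve_triangular(U, z, upper=True).squeeze(-1)
print(torch.allclose(x_rec, x, atol=1e-4))             # expected: True
```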
@@ -284,7 +281,7 @@ def backward(self, y: torch.Tensor) -> torch.Tensor:
@property
def L(self) -> torch.Tensor:
"""The lower triangular matrix $\mathbf{L}$ of the layers LU decomposition"""
return self.L_raw.tril().fill_diagonal_(1)
return self.L_raw.tril(-1) + torch.eye(self.dim)

@property
def U(self) -> torch.Tensor:
@@ -309,11 +306,14 @@ def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor) -> float:
Args:
x (torch.Tensor): input tensor
y (torch.Tensor): transformed tensor
Returns:
float: log absolute determinant of the Jacobian of the transform $(LU)x + \mathrm{bias}$
"""
return LA.slogdet(self.weight)[1]
U = self.U
dU = U - U.triu(1)
return dU.abs().log().sum()

def sign(self) -> int:
""" Computes the sign of the determinant of the Jacobian of the transform $(LU)x + \mathrm{bias}$.
@@ -324,7 +324,7 @@ def sign(self) -> int:
Returns:
float: sign of the determinant of the Jacobian of the transform $(LU)x + \mathrm{bias}$
"""
return LA.slogdet(self.weight)[0]
return self.L.diag().prod().sign() * self.U.diag().prod().sign()

def to(self, device) -> None:
""" Moves the layer to a given device
@@ -445,8 +445,6 @@ def to(self, device):
self.mask = self.mask.to(device)
return super().to(device)



class LeakyReLUTransform(BaseTransform):
bijective = True
domain = dist.constraints.real
@@ -504,4 +502,4 @@ def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor) -> float:
Returns:
float: log absolute determinant of the Jacobian of the transform
"""
return ((x <= 0).float() * math.log(self.alpha)).sum()
return math.log(y/x).sum()
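
For a leaky-ReLU bijection with $y = x$ for $x > 0$ and $y = \alpha x$ otherwise, the Jacobian is diagonal with entries $1$ or $\alpha$, so the two return expressions above agree mathematically: summing $\log \alpha$ over the non-positive inputs equals summing $\log(y/x)$ over all entries. A small check of that identity using elementwise tensor ops (`torch.log` rather than `math.log`, which expects a scalar); this is a sketch that assumes $\alpha > 0$ and no exact zeros in `x`:

```python
import torch

alpha = 0.01
x = torch.randn(6)
y = torch.where(x > 0, x, alpha * x)        # leaky ReLU with slope alpha

indicator_form = ((x <= 0).float() * torch.log(torch.tensor(alpha))).sum()
ratio_form = torch.log(y / x).sum()         # y/x is 1 where x > 0 and alpha otherwise
print(torch.allclose(indicator_form, ratio_form, atol=1e-6))   # expected: True
```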
3 changes: 2 additions & 1 deletion tests/onnx_test.py
@@ -7,6 +7,7 @@
def test_onnx():
loc = torch.zeros(2)
scale = torch.ones(2)
model = NiceFlow(Normal(loc, scale), 2, [10, 10], split_dim=1, permutation="half")
model = NiceFlow(Normal(loc, scale), 2, [10, 10], split_dim=1, permutation="LU")
model.to_onnx("log_prob.onnx")
model.to_onnx("sample.onnx", export_mode="sample")
