make LU forward efficient and Onnx compatible
fariedabuzaid committed Sep 27, 2023
1 parent 1990129 commit 91d1409
Showing 2 changed files with 67 additions and 15 deletions.
45 changes: 45 additions & 0 deletions src/veriflow/linalg.py
@@ -0,0 +1,45 @@
from typing import Optional

import torch


def solve_triangular(M: torch.Tensor, y: torch.Tensor, pivot: Optional[int] = None) -> torch.Tensor:
    """Re-implementation of torch.linalg.solve_triangular. Since ONNX export of the native method is
    currently not supported, we implement it with basic torch operations.

    Args:
        M: triangular matrix. May be upper or lower triangular.
        y: input vector.
        pivot: If given, determines whether to treat $M$ as a lower (pivot=0) or upper (pivot=-1)
            triangular matrix. Note that in this case there is no check whether $M$ is actually
            lower or upper triangular, respectively. Passing it therefore speeds up computation
            but should be used with care.
    Returns:
        (torch.Tensor): Solution of the system $Mx=y$
    """
    if pivot is None:
        # Determine orientation of the matrix: a lower triangular matrix is zero
        # strictly above the diagonal, an upper triangular one strictly below it.
        if all(M[i, j] == 0 for i in range(M.size(0)) for j in range(i + 1, M.size(1))):
            pivot = 0
        elif all(M[i, j] == 0 for i in range(M.size(0)) for j in range(0, i)):
            pivot = -1
        else:
            raise ValueError("M needs to be triangular.")
    elif pivot not in [0, -1]:
        raise ValueError("pivot needs to be either None, 0, or -1.")

    # Solve for the pivot variable first.
    x = torch.zeros_like(y)
    x[pivot] = y[pivot] / M[pivot, pivot]
    # Base case: a 1x1 system is solved by the pivot step alone.
    if y.size(0) == 1:
        return x

    # Eliminate the pivot variable and recurse on the remaining
    # (n-1) x (n-1) triangular subsystem.
    y_next = y - x[pivot] * M[:, pivot]
    if pivot == 0:
        y_next = y_next[1:]
        M_next = M[1:, 1:]
        x[1:] = solve_triangular(M_next, y_next, pivot=pivot)
    else:
        y_next = y_next[:-1]
        M_next = M[:-1, :-1]
        x[:-1] = solve_triangular(M_next, y_next, pivot=pivot)

    return x
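As a quick illustration (not part of the commit; tensor names are made up), the recursive solver can be checked against torch's native routine:

import torch
from src.veriflow.linalg import solve_triangular

torch.manual_seed(0)
n = 5
L = torch.tril(torch.rand(n, n)) + torch.eye(n)  # well-conditioned lower-triangular matrix
y = torch.rand(n)

x = solve_triangular(L, y, pivot=0)  # pivot=0: treat L as lower triangular, skip the orientation check
x_ref = torch.linalg.solve_triangular(L, y.unsqueeze(-1), upper=False).squeeze(-1)
assert torch.allclose(x, x_ref, atol=1e-5)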

37 changes: 22 additions & 15 deletions src/veriflow/transforms.py
@@ -1,13 +1,14 @@
-from abc import abstractmethod
import math
+from abc import abstractmethod
from typing import List

import numpy as np
import pyro
import torch
from pyro import distributions as dist
from pyro.distributions import constraints
-from pyro.distributions.transforms import AffineCoupling, LowerCholeskyAffine, Permute
+from pyro.distributions.transforms import (AffineCoupling, LowerCholeskyAffine,
+                                           Permute)
from pyro.infer import SVI
from pyro.nn import DenseNN
from sklearn.datasets import load_digits
@@ -18,6 +19,8 @@
from torch.nn import init
from tqdm import tqdm

+from src.veriflow.linalg import solve_triangular


class BaseTransform(dist.TransformModule):
"""Base class for transforms. Implemented as a thin layer on top of pyro's TransformModule. The baseTransform
@@ -180,9 +183,9 @@ def _inverse(self, y: torch.Tensor):

def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor):
"""
-Calculates the elementwise determinant of the log Jacobian, i.e.
+Calculates the element-wise determinant of the log Jacobian, i.e.
log(abs([dy_0/dx_0, ..., dy_{N-1}/dx_{N-1}])). Note that this type of
-transform is not autoregressive, so the log Jacobian is not the sum of the
+transform is not auto-regressive, so the log Jacobian is not the sum of the
previous expression. However, it turns out it's always 0 (since the
determinant is -1 or +1), and so returning a vector of zeros works.
"""
@@ -201,7 +204,7 @@ def with_cache(self, cache_size: int = 1):


class LUTransform(BaseTransform):
"""Implementation of a linear bijection transform. Applies a transform $y = \mathbf{L}\mathbf{U}x$, where $\mathbf{L}$ is a
"""Implementation of a linear bijection transform. Applies a transform $y = (\mathbf{L}\mathbf{U})^{-1}x$, where $\mathbf{L}$ is a
lower triangular matrix with unit diagonal and $\mathbf{U}$ is an upper triangular matrix. Bijectivity is guaranteed by
requiring that the diagonal elements of $\mathbf{U}$ are positive and the diagonal elements of $\mathbf{L}$ are all $1$.
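(Note: these conditions guarantee invertibility because the determinant of a triangular matrix is the product of its diagonal entries, so $\det(\mathbf{L}\mathbf{U}) = \det(\mathbf{L})\det(\mathbf{U}) = \prod_i \mathbf{U}_{ii} > 0$.)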
@@ -255,16 +258,23 @@ def init_params(self):
init.uniform_(self.bias, -bound, bound)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Computes the affine transform $(LU)x + \mathrm{bias}$
"""Computes the affine transform $y = (LU)^{-1}x + \mathrm{bias}$.
The value $y$ is computed by solving the linear equation system
\begin{align*}
Ly_0 &= x + LU\textrm{bias} \\
Uy &= y_0
\end{align*}
:param x: input tensor
:type x: torch.Tensor
:return: transformed tensor $(LU)^{-1}x + \mathrm{bias}$
"""
-return F.linear(x, self.weight, self.bias)
+# L and U are triangular by construction, so pass pivot explicitly to skip the orientation checks
+x0 = x + F.linear(self.bias, self.inv_weight)
+y0 = solve_triangular(self.L, x0, pivot=0)
+return solve_triangular(self.U, y0, pivot=-1)
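To see why this yields the docstring's formula (an editorial note, not part of the commit): multiplying $y = (LU)^{-1}x + \mathrm{bias}$ by $LU$ gives $LUy = x + LU\,\mathrm{bias}$, and substituting $y_0 = Uy$ splits the system into the two triangular solves above. A small numerical check, with all names illustrative:

import torch
from src.veriflow.linalg import solve_triangular

torch.manual_seed(0)
n = 4
L = torch.tril(torch.rand(n, n), -1) + torch.eye(n)  # unit lower-triangular
U = torch.triu(torch.rand(n, n), 1) + torch.eye(n)   # upper-triangular with positive diagonal
bias = torch.rand(n)
x = torch.rand(n)

x0 = x + (L @ U) @ bias
y0 = solve_triangular(L, x0, pivot=0)
y = solve_triangular(U, y0, pivot=-1)
y_ref = torch.linalg.inv(L @ U) @ x + bias
assert torch.allclose(y, y_ref, atol=1e-5)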

def backward(self, y: torch.Tensor) -> torch.Tensor:
"""Computes the inverse transform $(LU)^{-1}(y - \mathrm{bias})$
"""Computes the inverse transform $(LU)(y - \mathrm{bias})$
:param y: input tensor
:type y: torch.Tensor
@@ -286,11 +296,6 @@ def inv_weight(self) -> torch.Tensor:
"""Inverse weight matrix of the affine transform"""
return LA.matmul(self.L, self.U)

-@property
-def weight(self) -> torch.Tensor:
-    """Weight matrix of the affine transform"""
-    return LA.inv(LA.matmul(self.L, self.U))

def _call(self, x: torch.Tensor) -> torch.Tensor:
""" Alias for :func:`forward`"""
return self.forward(x)
@@ -349,6 +354,7 @@ def add_jitter(self, jitter: float = 1e-6) -> None:
self.U_raw
+ perturbation * torch.eye(self.dim, device=self.U_raw.device)
)



class MaskedCoupling(BaseTransform):
@@ -438,7 +444,8 @@ def to(self, device):
"""
self.mask = self.mask.to(device)
return super().to(device)




class LeakyReLUTransform(BaseTransform):
bijective = True
