-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
make LU forward efficient and Onnx compatible
- Loading branch information
1 parent
1990129
commit 91d1409
Showing
2 changed files
with
67 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from typing import Optional | ||
|
||
import torch | ||
|
||
|
||
def solve_triangular(M: torch.Tensor, y: torch.Tensor, pivot: Optional[int]=None) -> torch.Tensor: | ||
""" Re-implementation of torch solve_triangular. Since Onnx export of the native method is currently not supported, | ||
we implement it with basic torch operations. | ||
Args: | ||
M: triangular matrix. May be upper or lower triangular | ||
y: input vector. | ||
pivot: If given, determines wether to treat $M$ as a lower or upper triangular matrix. Note that in this case | ||
there is no check wether $M$ is actually lower or upper triangular, respectively. It therefore | ||
speeds up computation but should be used with care. | ||
Returns: | ||
(torch.Tensor): Solution of the system $Mx=y$ | ||
""" | ||
if pivot is None: | ||
# Determine orientation of Matrix | ||
if all([M[i, j] == 0 for i in range(M.size[0]) for j in range(i, M.size[1])]): | ||
pivot = 0 | ||
elif all([M[i, j] == 0 for i in range(M.size[0]) for j in range(0, i+1)]): | ||
pivot = -1 | ||
else: | ||
raise ValueError("M needs to be triangular.") | ||
elif pivot not in [0, -1]: | ||
raise ValueError("pivot needs to be either None, 0, or -1.") | ||
|
||
|
||
x = torch.zeros_like(y) | ||
x[pivot] = y[pivot] / M[pivot, pivot] | ||
|
||
y_next = (y - x[pivot] * L[:, pivot]) | ||
if pivot == 0: | ||
y_next = y_next[1:] | ||
M_next = M[1:, 1:] | ||
x[1:] = solve_triangular(y_next, M_next) | ||
else: | ||
y_next = y_next[:-1] | ||
LMnext = M[:-1, :-1] | ||
x[:-1] = solve_triangular(y_next, M_next) | ||
|
||
return x | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters