Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Renamed decomposition algorithms #162

Merged
merged 1 commit into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions matfree/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class _Alg(containers.NamedTuple):
"""Range of the for-loop used to decompose a matrix."""


def tridiagonal_full_reortho(depth, /):
def lanczos_tridiag_full_reortho(depth, /):
"""Construct an implementation of **tridiagonalisation**.

Uses pre-allocation. Fully reorthogonalise vectors at every step.
Expand Down Expand Up @@ -83,7 +83,7 @@ def extract(state: State, /):
return _Alg(init=init, step=apply, extract=extract, lower_upper=(0, depth + 1))


def bidiagonal_full_reortho(depth, /, matrix_shape):
def lanczos_bidiag_full_reortho(depth, /, matrix_shape):
"""Construct an implementation of **bidiagonalisation**.

Uses pre-allocation. Fully reorthogonalise vectors at every step.
Expand Down Expand Up @@ -161,7 +161,7 @@ def _gram_schmidt_orthogonalise(vec1, vec2):
return vec_ortho, coeff


def svd(
def svd_approx(
v0: Array, depth: int, Av: Callable, vA: Callable, matrix_shape: Tuple[int, ...]
):
"""Approximate singular value decomposition.
Expand All @@ -185,7 +185,7 @@ def svd(
Shape of the matrix involved in matrix-vector and vector-matrix products.
"""
# Factorise the matrix
algorithm = bidiagonal_full_reortho(depth, matrix_shape=matrix_shape)
algorithm = lanczos_bidiag_full_reortho(depth, matrix_shape=matrix_shape)
u, (d, e), vt, _ = decompose_fori_loop(v0, Av, vA, algorithm=algorithm)

# Compute SVD of factorisation
Expand Down Expand Up @@ -222,7 +222,7 @@ class _DecompAlg(containers.NamedTuple):
"""Decomposition algorithm type.

For example, the output of
[matfree.decomp.tridiagonal_full_reortho(...)][matfree.decomp.tridiagonal_full_reortho].
[matfree.decomp.lanczos_tridiag_full_reortho(...)][matfree.decomp.lanczos_tridiag_full_reortho].
"""


Expand Down
4 changes: 2 additions & 2 deletions matfree/slq.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def _quadratic_form_slq_spd(matfun, order, Av, /):
"""

def quadform(v0, /):
algorithm = decomp.tridiagonal_full_reortho(order)
algorithm = decomp.lanczos_tridiag_full_reortho(order)
_, tridiag = decomp.decompose_fori_loop(v0, Av, algorithm=algorithm)
(diag, off_diag) = tridiag

Expand Down Expand Up @@ -85,7 +85,7 @@ def _quadratic_form_slq_product(matfun, depth, *matvec_funs, matrix_shape):

def quadform(v0, /):
# Decompose into orthogonal-bidiag-orthogonal
algorithm = decomp.bidiagonal_full_reortho(depth, matrix_shape=matrix_shape)
algorithm = decomp.lanczos_bidiag_full_reortho(depth, matrix_shape=matrix_shape)
output = decomp.decompose_fori_loop(v0, *matvec_funs, algorithm=algorithm)
u, (d, e), vt, _ = output

Expand Down
10 changes: 5 additions & 5 deletions tests/test_decomp/test_bidiagonal_full_reortho.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ def A(nrows, ncols, num_significant_singular_vals):
@testing.parametrize("ncols", [49])
@testing.parametrize("num_significant_singular_vals", [4])
@testing.parametrize("order", [6]) # ~1.5 * num_significant_eigvals
def test_bidiagonal_full_reortho(A, order):
def test_lanczos_bidiag_full_reortho(A, order):
"""Test that Lanczos tridiagonalisation yields an orthogonal-tridiagonal decomp."""
nrows, ncols = np.shape(A)
key = prng.prng_key(1)
v0 = prng.normal(key, shape=(ncols,))
alg = decomp.bidiagonal_full_reortho(order, matrix_shape=np.shape(A))
alg = decomp.lanczos_bidiag_full_reortho(order, matrix_shape=np.shape(A))

def Av(v):
return A @ v
Expand Down Expand Up @@ -74,7 +74,7 @@ def test_error_too_high_depth(A):
max_depth = min(nrows, ncols) - 1

with testing.raises(ValueError):
_ = decomp.bidiagonal_full_reortho(max_depth + 1, matrix_shape=np.shape(A))
_ = decomp.lanczos_bidiag_full_reortho(max_depth + 1, matrix_shape=np.shape(A))


@testing.parametrize("nrows", [5])
Expand All @@ -84,7 +84,7 @@ def test_error_too_low_depth(A):
"""Assert that a ValueError is raised when the depth is negative."""
min_depth = 0
with testing.raises(ValueError):
_ = decomp.bidiagonal_full_reortho(min_depth - 1, matrix_shape=np.shape(A))
_ = decomp.lanczos_bidiag_full_reortho(min_depth - 1, matrix_shape=np.shape(A))


@testing.parametrize("nrows", [15])
Expand All @@ -93,7 +93,7 @@ def test_error_too_low_depth(A):
def test_no_error_zero_depth(A):
"""Assert the corner case of zero-depth does not raise an error."""
nrows, ncols = np.shape(A)
algorithm = decomp.bidiagonal_full_reortho(0, matrix_shape=np.shape(A))
algorithm = decomp.lanczos_bidiag_full_reortho(0, matrix_shape=np.shape(A))
key = prng.prng_key(1)
v0 = prng.normal(key, shape=(ncols,))

Expand Down
2 changes: 1 addition & 1 deletion tests/test_decomp/test_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def vA(v):
return v @ A

v0 = np.ones((ncols,))
U, S, Vt = decomp.svd(v0, depth, Av, vA, matrix_shape=np.shape(A))
U, S, Vt = decomp.svd_approx(v0, depth, Av, vA, matrix_shape=np.shape(A))
U_, S_, Vt_ = linalg.svd(A, full_matrices=False)

tols_decomp = {"atol": 1e-5, "rtol": 1e-5}
Expand Down
4 changes: 2 additions & 2 deletions tests/test_decomp/test_tridiagonal_full_reortho.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_max_order(A):
order = n - 1
key = prng.prng_key(1)
v0 = prng.normal(key, shape=(n,))
alg = decomp.tridiagonal_full_reortho(order)
alg = decomp.lanczos_tridiag_full_reortho(order)
Q, (d_m, e_m) = decomp.decompose_fori_loop(v0, lambda v: A @ v, algorithm=alg)

# Lanczos is not stable.
Expand Down Expand Up @@ -63,7 +63,7 @@ def test_identity(A, order):
n, _ = np.shape(A)
key = prng.prng_key(1)
v0 = prng.normal(key, shape=(n,))
alg = decomp.tridiagonal_full_reortho(order)
alg = decomp.lanczos_tridiag_full_reortho(order)
Q, tridiag = decomp.decompose_fori_loop(v0, lambda v: A @ v, algorithm=alg)
(d_m, e_m) = tridiag

Expand Down
Loading