Skip to content

Commit

Permalink
Renamed decomposition algorithms
Browse files Browse the repository at this point in the history
  • Loading branch information
pnkraemer committed Nov 21, 2023
1 parent 2210962 commit ed732c4
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 15 deletions.
10 changes: 5 additions & 5 deletions matfree/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class _Alg(containers.NamedTuple):
"""Range of the for-loop used to decompose a matrix."""


def tridiagonal_full_reortho(depth, /):
def lanczos_tridiag_full_reortho(depth, /):
"""Construct an implementation of **tridiagonalisation**.
Uses pre-allocation. Fully reorthogonalise vectors at every step.
Expand Down Expand Up @@ -83,7 +83,7 @@ def extract(state: State, /):
return _Alg(init=init, step=apply, extract=extract, lower_upper=(0, depth + 1))


def bidiagonal_full_reortho(depth, /, matrix_shape):
def lanczos_bidiag_full_reortho(depth, /, matrix_shape):
"""Construct an implementation of **bidiagonalisation**.
Uses pre-allocation. Fully reorthogonalise vectors at every step.
Expand Down Expand Up @@ -161,7 +161,7 @@ def _gram_schmidt_orthogonalise(vec1, vec2):
return vec_ortho, coeff


def svd(
def svd_approx(
v0: Array, depth: int, Av: Callable, vA: Callable, matrix_shape: Tuple[int, ...]
):
"""Approximate singular value decomposition.
Expand All @@ -185,7 +185,7 @@ def svd(
Shape of the matrix involved in matrix-vector and vector-matrix products.
"""
# Factorise the matrix
algorithm = bidiagonal_full_reortho(depth, matrix_shape=matrix_shape)
algorithm = lanczos_bidiag_full_reortho(depth, matrix_shape=matrix_shape)
u, (d, e), vt, _ = decompose_fori_loop(v0, Av, vA, algorithm=algorithm)

# Compute SVD of factorisation
Expand Down Expand Up @@ -222,7 +222,7 @@ class _DecompAlg(containers.NamedTuple):
"""Decomposition algorithm type.
For example, the output of
[matfree.decomp.tridiagonal_full_reortho(...)][matfree.decomp.tridiagonal_full_reortho].
[matfree.decomp.lanczos_tridiag_full_reortho(...)][matfree.decomp.lanczos_tridiag_full_reortho].
"""


Expand Down
4 changes: 2 additions & 2 deletions matfree/slq.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def _quadratic_form_slq_spd(matfun, order, Av, /):
"""

def quadform(v0, /):
algorithm = decomp.tridiagonal_full_reortho(order)
algorithm = decomp.lanczos_tridiag_full_reortho(order)
_, tridiag = decomp.decompose_fori_loop(v0, Av, algorithm=algorithm)
(diag, off_diag) = tridiag

Expand Down Expand Up @@ -85,7 +85,7 @@ def _quadratic_form_slq_product(matfun, depth, *matvec_funs, matrix_shape):

def quadform(v0, /):
# Decompose into orthogonal-bidiag-orthogonal
algorithm = decomp.bidiagonal_full_reortho(depth, matrix_shape=matrix_shape)
algorithm = decomp.lanczos_bidiag_full_reortho(depth, matrix_shape=matrix_shape)
output = decomp.decompose_fori_loop(v0, *matvec_funs, algorithm=algorithm)
u, (d, e), vt, _ = output

Expand Down
10 changes: 5 additions & 5 deletions tests/test_decomp/test_bidiagonal_full_reortho.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ def A(nrows, ncols, num_significant_singular_vals):
@testing.parametrize("ncols", [49])
@testing.parametrize("num_significant_singular_vals", [4])
@testing.parametrize("order", [6]) # ~1.5 * num_significant_eigvals
def test_bidiagonal_full_reortho(A, order):
def test_lanczos_bidiag_full_reortho(A, order):
"""Test that Lanczos tridiagonalisation yields an orthogonal-tridiagonal decomp."""
nrows, ncols = np.shape(A)
key = prng.prng_key(1)
v0 = prng.normal(key, shape=(ncols,))
alg = decomp.bidiagonal_full_reortho(order, matrix_shape=np.shape(A))
alg = decomp.lanczos_bidiag_full_reortho(order, matrix_shape=np.shape(A))

def Av(v):
return A @ v
Expand Down Expand Up @@ -74,7 +74,7 @@ def test_error_too_high_depth(A):
max_depth = min(nrows, ncols) - 1

with testing.raises(ValueError):
_ = decomp.bidiagonal_full_reortho(max_depth + 1, matrix_shape=np.shape(A))
_ = decomp.lanczos_bidiag_full_reortho(max_depth + 1, matrix_shape=np.shape(A))


@testing.parametrize("nrows", [5])
Expand All @@ -84,7 +84,7 @@ def test_error_too_low_depth(A):
"""Assert that a ValueError is raised when the depth is negative."""
min_depth = 0
with testing.raises(ValueError):
_ = decomp.bidiagonal_full_reortho(min_depth - 1, matrix_shape=np.shape(A))
_ = decomp.lanczos_bidiag_full_reortho(min_depth - 1, matrix_shape=np.shape(A))


@testing.parametrize("nrows", [15])
Expand All @@ -93,7 +93,7 @@ def test_error_too_low_depth(A):
def test_no_error_zero_depth(A):
"""Assert the corner case of zero-depth does not raise an error."""
nrows, ncols = np.shape(A)
algorithm = decomp.bidiagonal_full_reortho(0, matrix_shape=np.shape(A))
algorithm = decomp.lanczos_bidiag_full_reortho(0, matrix_shape=np.shape(A))
key = prng.prng_key(1)
v0 = prng.normal(key, shape=(ncols,))

Expand Down
2 changes: 1 addition & 1 deletion tests/test_decomp/test_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def vA(v):
return v @ A

v0 = np.ones((ncols,))
U, S, Vt = decomp.svd(v0, depth, Av, vA, matrix_shape=np.shape(A))
U, S, Vt = decomp.svd_approx(v0, depth, Av, vA, matrix_shape=np.shape(A))
U_, S_, Vt_ = linalg.svd(A, full_matrices=False)

tols_decomp = {"atol": 1e-5, "rtol": 1e-5}
Expand Down
4 changes: 2 additions & 2 deletions tests/test_decomp/test_tridiagonal_full_reortho.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_max_order(A):
order = n - 1
key = prng.prng_key(1)
v0 = prng.normal(key, shape=(n,))
alg = decomp.tridiagonal_full_reortho(order)
alg = decomp.lanczos_tridiag_full_reortho(order)
Q, (d_m, e_m) = decomp.decompose_fori_loop(v0, lambda v: A @ v, algorithm=alg)

# Lanczos is not stable.
Expand Down Expand Up @@ -63,7 +63,7 @@ def test_identity(A, order):
n, _ = np.shape(A)
key = prng.prng_key(1)
v0 = prng.normal(key, shape=(n,))
alg = decomp.tridiagonal_full_reortho(order)
alg = decomp.lanczos_tridiag_full_reortho(order)
Q, tridiag = decomp.decompose_fori_loop(v0, lambda v: A @ v, algorithm=alg)
(d_m, e_m) = tridiag

Expand Down

0 comments on commit ed732c4

Please sign in to comment.