Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement eig.eigh_partial and allow parametric matvecs in partial SVD & eig #233

Merged
merged 7 commits into from
Jan 7, 2025
4 changes: 4 additions & 0 deletions matfree/backend/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ def eigh(x, /):
return jnp.linalg.eigh(x)


def eig(x, /):
return jnp.linalg.eig(x)


def cholesky(x, /):
return jnp.linalg.cholesky(x)

Expand Down
31 changes: 29 additions & 2 deletions matfree/eig.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ def svd_partial(bidiag: Callable):

"""

def svd(Av: Callable, v0: Array):
def svd(Av: Callable, v0: Array, *parameters):
# Factorise the matrix
(u, v), B, *_ = bidiag(Av, v0)
(u, v), B, *_ = bidiag(Av, v0, *parameters)

# Compute SVD of factorisation
U, S, Vt = linalg.svd(B, full_matrices=False)
Expand All @@ -31,3 +31,30 @@ def svd(Av: Callable, v0: Array):
return u @ U, S, Vt @ v.T

return svd


def eigh_partial(tridiag_sym: Callable):
"""Partial symmetric/Hermitian eigenvalue decomposition.

Combines tridiagonalization with a decomposition
of the (small) tridiagonal matrix.

Parameters
----------
tridiag_sym:
An implementation of tridiagonalization.
For example, the output of
[decomp.tridiag_sym][matfree.decomp.tridiag_sym].

"""

def eigh(Av: Callable, v0: Array, *parameters):
# Factorise the matrix
Q, H, *_ = tridiag_sym(Av, v0, *parameters)

# Compute SVD of factorisation
vals, vecs = linalg.eigh(H)
vecs = Q @ vecs
return vals, vecs

return eigh
38 changes: 38 additions & 0 deletions tests/test_eig/test_eigh_partial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""Tests for eigenvalue functionality."""

from matfree import decomp, eig, test_util
from matfree.backend import linalg, np, testing


@testing.fixture()
@testing.parametrize("nrows", [10])
def A(nrows):
"""Make a positive definite matrix with certain spectrum."""
# 'Invent' a spectrum. Use the number of pre-defined eigenvalues.
d = np.arange(1.0, 1.0 + nrows)
return test_util.asymmetric_matrix_from_singular_values(d, nrows=nrows, ncols=nrows)


def test_equal_to_linalg_eigh(A):
"""The output of full-depth decomposition should be equal (*) to linalg.svd().

(*) Note: The singular values should be identical,
and the orthogonal matrices should be orthogonal. They are not unique.
"""
nrows, ncols = np.shape(A)
num_matvecs = min(nrows, ncols)

v0 = np.ones((ncols,))
v0 /= linalg.vector_norm(v0)

hessenberg = decomp.hessenberg(num_matvecs, reortho="full")
alg = eig.eigh_partial(hessenberg)
S, U = alg(lambda v, p: p @ v, v0, A)

S_, U_ = linalg.eigh(A)

assert np.allclose(S, S_)
assert np.shape(U) == np.shape(U_)

tols_decomp = {"atol": 1e-5, "rtol": 1e-5}
assert np.allclose(U @ U.T, U_ @ U_.T, **tols_decomp)
5 changes: 1 addition & 4 deletions tests/test_eig/test_svd_partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,12 @@ def test_equal_to_linalg_svd(A):
nrows, ncols = np.shape(A)
num_matvecs = min(nrows, ncols)

def Av(v):
return A @ v

v0 = np.ones((ncols,))
v0 /= linalg.vector_norm(v0)

bidiag = decomp.bidiag(num_matvecs)
svd = eig.svd_partial(bidiag)
U, S, Vt = svd(Av, v0)
U, S, Vt = svd(lambda v, p: p @ v, v0, A)

U_, S_, Vt_ = linalg.svd(A, full_matrices=False)

Expand Down
Loading