Skip to content

Commit

Permalink
Change ordering of the returned vecs in eig.eig_partial
Browse files Browse the repository at this point in the history
  • Loading branch information
pnkraemer committed Jan 8, 2025
1 parent 08696db commit 562ab36
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 28 deletions.
9 changes: 8 additions & 1 deletion matfree/backend/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,14 @@ def eigh(x, /):


def eig(x, /):
return jnp.linalg.eig(x)
vals, vecs = jnp.linalg.eig(x)

# Order the values and vectors by magnitude
# (jax.numpy.linalg.eig does not do this)
ordered = jnp.argsort(vals)[::-1]
vals = vals[ordered]
vecs = vecs[:, ordered]
return vals, vecs


def cholesky(x, /):
Expand Down
10 changes: 2 additions & 8 deletions matfree/eig.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Matrix-free eigenvalue and singular-value analysis."""

from matfree.backend import linalg, np
from matfree.backend import linalg
from matfree.backend.typing import Array, Callable


Expand Down Expand Up @@ -82,12 +82,6 @@ def eig(Av: Callable, v0: Array, *parameters):
# Compute SVD of factorisation
vals, vecs = linalg.eig(H)
vecs = Q @ vecs

# Return in descending order
# (jnp.linalg.eig doesn't do this by default)
args = np.argsort(vals)[::-1]
vals = vals[args]
vecs = vecs[:, args]
return vals, vecs
return vals, vecs.T

return eig
26 changes: 7 additions & 19 deletions tests/test_eig/test_eig_partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,16 @@ def test_equal_to_linalg_eig(nrows=7):
and the orthogonal matrices should be orthogonal. They are not unique.
"""
A = np.triu(np.arange(1.0, 1.0 + nrows**2).reshape((nrows, nrows)))

nrows, ncols = np.shape(A)
num_matvecs = min(nrows, ncols)

v0 = np.ones((ncols,))
v0 /= linalg.vector_norm(v0)
v0 = np.ones((nrows,))
num_matvecs = nrows

hessenberg = decomp.hessenberg(num_matvecs, reortho="full")
alg = eig.eig_partial(hessenberg)
S, U = alg(lambda v, p: p @ v, v0, A)

# Ensure the reference is ordered
S_, U_ = linalg.eig(A)
ordered = np.argsort(S_)[::-1]
S_ = S_[ordered]
U_ = U_[:, ordered]

assert np.allclose(S, S_)
assert np.shape(U) == np.shape(U_)
vals, vecs = alg(lambda v, p: p @ v, v0, A)

tols_decomp = {"atol": 1e-5, "rtol": 1e-5}
assert np.allclose(U @ U.T, U_ @ U_.T, **tols_decomp)
S, U = linalg.eig(A)
assert np.allclose(vecs.T @ vecs, U @ U.T, atol=1e-5, rtol=1e-5)
assert np.allclose(vals, S)


@testing.parametrize("nrows", [8])
Expand All @@ -47,4 +35,4 @@ def test_shapes_as_expected(nrows, num_matvecs):
alg = eig.eig_partial(hessenberg)
S, U = alg(lambda v, p: p @ v, v0, A)
assert S.shape == (num_matvecs,)
assert U.shape == (nrows, num_matvecs)
assert U.shape == (num_matvecs, nrows)

0 comments on commit 562ab36

Please sign in to comment.