Make all decompositions have similar output types (#207)
* Revise the output of decomp.bidiag

* Update the Hessenberg code to the new output type

* Update the tridiag API

* Update the SVD function calls

* Update the decomposition calls

* Make the DecompResult type compatible with python3.9
pnkraemer authored Aug 29, 2024
1 parent 085624d commit 16f261b
Showing 12 changed files with 121 additions and 107 deletions.
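
For orientation, a minimal sketch (not part of the commit) of the unified calling convention this change introduces. It follows the updated bidiagonalisation test further down this page; the matrix, random seed, and depth are invented for illustration:

import jax
import jax.numpy as jnp
from matfree import decomp

# Hypothetical inputs, mirroring tests/test_decomp/test_bidiag.py below.
key = jax.random.PRNGKey(1)
A = jax.random.normal(key, (50, 49))
v0 = jnp.ones((49,))
depth = 6

algorithm = decomp.bidiag(depth, matrix_shape=A.shape, materialize=True)
(U, V), B, res, length_inv = algorithm(lambda x: A @ x, lambda x: x @ A, v0)

# Every decomposition now returns the same four fields (see _DecompResult below):
#   Q_tall:          one orthonormal basis, or a tuple of bases -- here (U, V)
#   J_small:         the small factor, dense because materialize=True
#   residual:        the residual vector of the factorisation (res)
#   init_length_inv: 1 / ||v0||
assert U.shape == (50, depth + 1) and V.shape == (49, depth + 1)
assert jnp.allclose(length_inv, 1.0 / jnp.linalg.norm(v0))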
10 changes: 9 additions & 1 deletion matfree/backend/typing.py
@@ -2,7 +2,15 @@

# fmt: off
from collections.abc import Callable # noqa: F401
from typing import Any, Generic, Iterable, Sequence, Tuple, TypeVar # noqa: F401, UP035
from typing import ( # noqa: F401, UP035
Any,
Generic,
Iterable,
Sequence,
Tuple,
TypeVar,
Union,
)

from jax import Array # noqa: F401

52 changes: 38 additions & 14 deletions matfree/decomp.py
@@ -9,7 +9,22 @@
"""

from matfree.backend import containers, control_flow, func, linalg, np, tree_util
from matfree.backend.typing import Array, Callable
from matfree.backend.typing import Array, Callable, Union


class _DecompResult(containers.NamedTuple):
# If an algorithm returns a single Q, place it here.
# If it returns multiple Qs, stack them
# into a tuple and place them here.
Q_tall: Union[Array, tuple[Array, ...]]

# If an algorithm returns a materialized matrix,
# place it here. If it returns a sparse representation
    # (e.g. two vectors representing diagonals), place it here, too.
J_small: Union[Array, tuple[Array, ...]]

residual: Array
init_length_inv: Array


def tridiag_sym(
@@ -69,21 +84,19 @@ def _tridiag_reortho_full(krylov_depth: int, /, *, custom_vjp: bool, materialize
alg = hessenberg(krylov_depth, custom_vjp=custom_vjp, reortho="full")

def estimate(matvec, vec, *params):
Q, H, v, _norm = alg(matvec, vec, *params)

remainder = (v / linalg.vector_norm(v), linalg.vector_norm(v))
Q, H, v, norm = alg(matvec, vec, *params)

T = 0.5 * (H + H.T)
diags = linalg.diagonal(T, offset=0)
offdiags = linalg.diagonal(T, offset=1)

matrix = (diags, offdiags)
if materialize:
matrix = _todense_tridiag_sym(diags, offdiags)
decomposition = (Q.T, matrix)
return decomposition, remainder

decomposition = (Q.T, (diags, offdiags))
return decomposition, remainder
return _DecompResult(
Q_tall=Q, J_small=matrix, residual=v, init_length_inv=1.0 / norm
)

return estimate

@@ -97,11 +110,16 @@ def _todense_tridiag_sym(diag, off_diag):

def _tridiag_reortho_none(krylov_depth: int, /, *, custom_vjp: bool, materialize: bool):
def estimate(matvec, vec, *params):
(Q, H), v = _estimate(matvec, vec, *params)
(Q, H), (q, b) = _estimate(matvec, vec, *params)
v = b * q

if materialize:
H = _todense_tridiag_sym(*H)
return (Q, H), v

length = linalg.vector_norm(vec)
return _DecompResult(
Q_tall=Q.T, J_small=H, residual=v, init_length_inv=1.0 / length
)

def _estimate(matvec, vec, *params):
*values, _ = _tridiag_forward(matvec, krylov_depth, vec, *params)
@@ -113,8 +131,6 @@ def estimate_fwd(matvec, vec, *params):

def estimate_bwd(matvec, cache, vjp_incoming):
# Read incoming gradients and stack related quantities
print(tree_util.tree_map(np.shape, vjp_incoming))

(dxs, (dalphas, dbetas)), (dx_last, dbeta_last) = vjp_incoming
dxs = np.concatenate((dxs, dx_last[None]))
dbetas = np.concatenate((dbetas, dbeta_last[None]))
@@ -366,7 +382,9 @@ def forward_step(i, val):

# Loop and return
Q, H, v, _length = control_flow.fori_loop(0, k, forward_step, init)
return Q, H, v, 1 / initlength
return _DecompResult(
Q_tall=Q, J_small=H, residual=v, init_length_inv=1 / initlength
)


def _hessenberg_forward_step(Q, H, v, length, matvec, *params, idx, reortho: str):
@@ -553,7 +571,13 @@ def body_fun(_, s):
result = control_flow.fori_loop(
0, depth + 1, body_fun=body_fun, init_val=init_val
)
return *extract(result), 1 / length
uk_all_T, J, vk_all, (beta, vk) = extract(result)
return _DecompResult(
Q_tall=(uk_all_T, vk_all.T),
J_small=J,
residual=beta * vk,
init_length_inv=1 / length,
)

class State(containers.NamedTuple):
i: int
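
To round off the decomp.py changes, a hedged sketch (not part of the commit) of the Hessenberg/Arnoldi call under the new output type, following tests/test_decomp/test_hessenberg.py below. The reortho="full" value is taken from _tridiag_reortho_full above; the matrix, seed, and tolerances are invented:

import jax
import jax.numpy as jnp
from matfree import decomp

key = jax.random.PRNGKey(0)
A = jax.random.normal(key, (10, 10))
v = jnp.ones((10,))
krylov_depth = 5

# The result is a _DecompResult: Q_tall, J_small, residual, init_length_inv.
Q, H, r, c = decomp.hessenberg(krylov_depth, reortho="full")(lambda x: A @ x, v)

# The four fields satisfy the Arnoldi relations that the tests check:
e0, ek = jnp.eye(krylov_depth)[[0, -1], :]
assert jnp.allclose(A @ Q, Q @ H + jnp.outer(r, ek), atol=1e-3)  # factorisation
assert jnp.allclose(Q.T @ Q, jnp.eye(krylov_depth), atol=1e-3)   # orthonormal columns
assert jnp.allclose(Q @ e0, c * v, atol=1e-3)                    # c == 1 / ||v||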
4 changes: 2 additions & 2 deletions matfree/eig.py
@@ -31,13 +31,13 @@ def svd_partial(
"""
# Factorise the matrix
algorithm = decomp.bidiag(depth, matrix_shape=matrix_shape, materialize=True)
u, B, vt, *_ = algorithm(Av, vA, v0)
(u, v), B, *_ = algorithm(Av, vA, v0)

# Compute SVD of factorisation
U, S, Vt = linalg.svd(B, full_matrices=False)

# Combine orthogonal transformations
return u @ U, S, Vt @ vt
return u @ U, S, Vt @ v.T


def _bidiagonal_dense(d, e):
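
The eig.py change only swaps the orientation of the second basis (V with orthonormal columns replaces vt), so the recombination step remains a plain matrix product. A hedged, self-contained sketch of that recombination (not part of the commit; jnp.linalg.svd stands in for the backend's linalg.svd, and the matrix, seed, and depth are invented):

import jax
import jax.numpy as jnp
from matfree import decomp

key = jax.random.PRNGKey(1)
A = jax.random.normal(key, (50, 49))
v0 = jnp.ones((49,))
depth = 6

algorithm = decomp.bidiag(depth, matrix_shape=A.shape, materialize=True)
(u, v), B, *_ = algorithm(lambda x: A @ x, lambda x: x @ A, v0)

# The same recombination as in svd_partial above: the SVD of the small
# bidiagonal factor B is lifted back through the orthonormal bases u and v.
U, S, Vt = jnp.linalg.svd(B, full_matrices=False)
U_combined, Vt_combined = u @ U, Vt @ v.T

# Both combined factors stay orthonormal, so (U_combined, S, Vt_combined)
# is again an orthogonal-diagonal-orthogonal (partial SVD) factorisation.
assert jnp.allclose(U_combined.T @ U_combined, jnp.eye(depth + 1), atol=1e-3)
assert jnp.allclose(Vt_combined @ Vt_combined.T, jnp.eye(depth + 1), atol=1e-3)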
8 changes: 4 additions & 4 deletions matfree/funm.py
@@ -136,12 +136,12 @@ def funm_lanczos_sym(dense_funm: Callable, tridiag_sym: Callable, /) -> Callable
def estimate(matvec: Callable, vec, *parameters):
length = linalg.vector_norm(vec)
vec /= length
(basis, matrix), _ = tridiag_sym(matvec, vec, *parameters)
Q, matrix, *_ = tridiag_sym(matvec, vec, *parameters)
# matrix = _todense_tridiag_sym(diag, off_diag)

funm = dense_funm(matrix)
e1 = np.eye(len(matrix))[0, :]
return length * (basis.T @ funm @ e1)
return length * (Q @ funm @ e1)

return estimate

@@ -205,7 +205,7 @@ def matvec_flat(v_flat, *p):
flat, unflatten = tree_util.ravel_pytree(Av)
return flat

(_, dense), _ = algorithm(matvec_flat, v0_flat, *parameters)
_, dense, *_ = algorithm(matvec_flat, v0_flat, *parameters)

fA = dense_funm(dense)
e1 = np.eye(len(fA))[0, :]
@@ -264,7 +264,7 @@ def vecmat_flat(w_flat):
# Decompose into orthogonal-bidiag-orthogonal
matvec_flat_p = lambda v: matvec_flat(v)[0] # noqa: E731
output = algorithm(matvec_flat_p, vecmat_flat, v0_flat, *parameters)
u, B, vt, *_ = output
_u, B, *_ = output

# Compute SVD of factorisation
# todo: turn the following lines into dense_funm_svd()
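
The funm.py updates only re-orient Q, so the quantity funm_lanczos_sym estimates is unchanged: f(A) @ vec ≈ ||vec|| * Q @ f(T) @ e1, with Q and T taken from the symmetric Lanczos decomposition. A hedged sketch of that formula against the new output type (not part of the commit); the materialize keyword on tridiag_sym is assumed from the helper signatures above, and the eigh-based matrix logarithm is only a stand-in for whichever dense_funm is used in practice:

import jax
import jax.numpy as jnp
from matfree import decomp

key = jax.random.PRNGKey(2)
M = jax.random.normal(key, (12, 12))
A = M @ M.T + 12.0 * jnp.eye(12)    # symmetric positive definite test matrix
vec = jnp.ones((12,))
length = jnp.linalg.norm(vec)

tridiag = decomp.tridiag_sym(6, materialize=True)  # 'materialize' kwarg is an assumption
Q, T, *_ = tridiag(lambda x: A @ x, vec)           # Q_tall, J_small, ...

# Stand-in for dense_funm: the matrix logarithm of T via an eigendecomposition.
eigvals, eigvecs = jnp.linalg.eigh(T)
fT = eigvecs @ jnp.diag(jnp.log(eigvals)) @ eigvecs.T

e1 = jnp.eye(len(fT))[0, :]
approx = length * (Q @ fT @ e1)     # approximates log(A) @ vec, cf. estimate() above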
17 changes: 17 additions & 0 deletions matfree/test_util.py
@@ -47,3 +47,20 @@ def tree_random_like(key, tree, *, generate_func=prng.normal):
flat, unflatten = tree_util.ravel_pytree(tree)
flat_like = generate_func(key, shape=flat.shape, dtype=flat.dtype)
return unflatten(flat_like)


def assert_columns_orthonormal(Q, /):
eye_like = Q.T @ Q
ref = np.eye(len(eye_like))
assert_allclose(eye_like, ref)


def assert_allclose(a, b, /):
a = np.asarray(a)
b = np.asarray(b)
tol = np.sqrt(np.finfo_eps(np.dtype(b)))

# For double precision sqrt(eps) is very tight...
if tol < 1e-7:
tol *= 10
assert np.allclose(a, b, atol=tol, rtol=tol)
64 changes: 23 additions & 41 deletions tests/test_decomp/test_bidiag.py
@@ -4,8 +4,7 @@
from matfree.backend import linalg, np, prng, testing


@testing.fixture()
def A(nrows, ncols, num_significant_singular_vals):
def make_A(nrows, ncols, num_significant_singular_vals):
"""Make a positive definite matrix with certain spectrum."""
# 'Invent' a spectrum. Use the number of pre-defined eigenvalues.
n = min(nrows, ncols)
@@ -18,9 +17,11 @@ def A(nrows, ncols, num_significant_singular_vals):
@testing.parametrize("ncols", [49])
@testing.parametrize("num_significant_singular_vals", [4])
@testing.parametrize("order", [6]) # ~1.5 * num_significant_eigvals
def test_bidiag_decomposition_is_satisfied(A, order):
def test_bidiag_decomposition_is_satisfied(
nrows, ncols, num_significant_singular_vals, order
):
"""Test that Lanczos tridiagonalisation yields an orthogonal-tridiagonal decomp."""
nrows, ncols = np.shape(A)
A = make_A(nrows, ncols, num_significant_singular_vals)
key = prng.prng_key(1)
v0 = prng.normal(key, shape=(ncols,))

@@ -30,43 +31,24 @@ def Av(v):
def vA(v):
return v @ A

algorithm = decomp.bidiag(order, matrix_shape=np.shape(A), materialize=False)
Us, Bs, Vs, (b, v), ln = algorithm(Av, vA, v0)
(d_m, e_m) = Bs
algorithm = decomp.bidiag(order, matrix_shape=np.shape(A), materialize=True)
(U, V), B, res, ln = algorithm(Av, vA, v0)

tols_decomp = {"atol": 1e-5, "rtol": 1e-5}

assert np.shape(Us) == (nrows, order + 1)
assert np.allclose(Us.T @ Us, np.eye(order + 1), **tols_decomp), Us.T @ Us

assert np.shape(Vs) == (order + 1, ncols)
assert np.allclose(Vs @ Vs.T, np.eye(order + 1), **tols_decomp), Vs @ Vs.T

UAVt = Us.T @ A @ Vs.T
assert np.allclose(linalg.diagonal(UAVt), d_m, **tols_decomp)
assert np.allclose(linalg.diagonal(UAVt, 1), e_m, **tols_decomp)

B = test_util.to_dense_bidiag(d_m, e_m)
assert np.shape(B) == (order + 1, order + 1)
assert np.allclose(UAVt, B, **tols_decomp)
test_util.assert_columns_orthonormal(U)
test_util.assert_columns_orthonormal(V)

em = np.eye(order + 1)[:, -1]
AVt = A @ Vs.T
UtB = Us @ B
AtUt = A.T @ Us
VtBtb_plus_bve = Vs.T @ B.T + b * v[:, None] @ em[None, :]
assert np.allclose(AVt, UtB, **tols_decomp)
assert np.allclose(AtUt, VtBtb_plus_bve, **tols_decomp)

assert np.allclose(ln, 1.0 / linalg.vector_norm(v0))
test_util.assert_allclose(A @ V, U @ B)
test_util.assert_allclose(A.T @ U, V @ B.T + linalg.outer(res, em))
test_util.assert_allclose(1.0 / linalg.vector_norm(v0), ln)


@testing.parametrize("nrows", [5])
@testing.parametrize("ncols", [3])
@testing.parametrize("num_significant_singular_vals", [3])
def test_error_too_high_depth(A):
def test_error_too_high_depth(nrows, ncols, num_significant_singular_vals):
"""Assert that a ValueError is raised when the depth exceeds the matrix size."""
nrows, ncols = np.shape(A)
A = make_A(nrows, ncols, num_significant_singular_vals)
max_depth = min(nrows, ncols) - 1

with testing.raises(ValueError, match=""):
@@ -76,8 +58,9 @@ def test_error_too_high_depth(A):
@testing.parametrize("nrows", [5])
@testing.parametrize("ncols", [3])
@testing.parametrize("num_significant_singular_vals", [3])
def test_error_too_low_depth(A):
def test_error_too_low_depth(nrows, ncols, num_significant_singular_vals):
"""Assert that a ValueError is raised when the depth is negative."""
A = make_A(nrows, ncols, num_significant_singular_vals)
min_depth = 0
with testing.raises(ValueError, match=""):
_ = decomp.bidiag(min_depth - 1, matrix_shape=np.shape(A), materialize=False)
@@ -86,9 +69,9 @@ def test_error_too_low_depth(A):
@testing.parametrize("nrows", [15])
@testing.parametrize("ncols", [3])
@testing.parametrize("num_significant_singular_vals", [3])
def test_no_error_zero_depth(A):
def test_no_error_zero_depth(nrows, ncols, num_significant_singular_vals):
"""Assert the corner case of zero-depth does not raise an error."""
nrows, ncols = np.shape(A)
A = make_A(nrows, ncols, num_significant_singular_vals)
key = prng.prng_key(1)
v0 = prng.normal(key, shape=(ncols,))

@@ -99,12 +82,11 @@ def vA(v):
return v @ A

algorithm = decomp.bidiag(0, matrix_shape=np.shape(A), materialize=False)
Us, Bs, Vs, (b, v), ln = algorithm(Av, vA, v0)
(d_m, e_m) = Bs
assert np.shape(Us) == (nrows, 1)
assert np.shape(Vs) == (1, ncols)
(U, V), (d_m, e_m), res, ln = algorithm(Av, vA, v0)

assert np.shape(U) == (nrows, 1)
assert np.shape(V) == (ncols, 1)
assert np.shape(d_m) == (1,)
assert np.shape(e_m) == (0,)
assert np.shape(b) == ()
assert np.shape(v) == (ncols,)
assert np.shape(res) == (ncols,)
assert np.shape(ln) == ()
22 changes: 7 additions & 15 deletions tests/test_decomp/test_hessenberg.py
@@ -1,6 +1,6 @@
"""Tests for Hessenberg factorisations (-> Arnoldi)."""

from matfree import decomp
from matfree import decomp, test_util
from matfree.backend import linalg, np, prng, testing


@@ -23,15 +23,11 @@ def test_decomposition_is_satisfied(nrows, krylov_depth, reortho, dtype):
assert r.shape == (nrows,)
assert c.shape == ()

# Tie the test-strictness to the floating point accuracy
small_value = np.sqrt(np.finfo_eps(np.dtype(H)))
tols = {"atol": small_value, "rtol": small_value}

# Test the decompositions
e0, ek = np.eye(krylov_depth)[[0, -1], :]
assert np.allclose(A @ Q - Q @ H - linalg.outer(r, ek), 0.0, **tols)
assert np.allclose(Q.T.conj() @ Q - np.eye(krylov_depth), 0.0, **tols)
assert np.allclose(Q @ e0, c * v, **tols)
test_util.assert_allclose(A @ Q - Q @ H - linalg.outer(r, ek), 0.0)
test_util.assert_allclose(Q.T.conj() @ Q - np.eye(krylov_depth), 0.0)
test_util.assert_allclose(Q @ e0, c * v)


@testing.parametrize("nrows", [10])
@@ -52,15 +48,11 @@ def test_reorthogonalisation_improves_the_estimate(nrows, krylov_depth, reortho)
assert r.shape == (nrows,)
assert c.shape == ()

# Tie the test-strictness to the floating point accuracy
small_value = np.sqrt(np.finfo_eps(np.dtype(H)))
tols = {"atol": small_value, "rtol": small_value}

# Test the decompositions
e0, ek = np.eye(krylov_depth)[[0, -1], :]
assert np.allclose(A @ Q - Q @ H - linalg.outer(r, ek), 0.0, **tols)
assert np.allclose(Q.T @ Q - np.eye(krylov_depth), 0.0, **tols)
assert np.allclose(Q @ e0, c * v, **tols)
test_util.assert_allclose(A @ Q - Q @ H - linalg.outer(r, ek), 0.0)
test_util.assert_allclose(Q.T @ Q - np.eye(krylov_depth), 0.0)
test_util.assert_allclose(Q @ e0, c * v)


def test_raises_error_for_wrong_depth_too_small():
