From fddc9a712cc722e14545e0ae123c7086fb67efd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Fri, 31 May 2024 15:54:28 +0200 Subject: [PATCH] Add a 'materialize' option to tridiag() and bidiag() to simplify downstream applications (#204) --- matfree/decomp.py | 68 +++++++++++++++---- matfree/eig.py | 5 +- matfree/funm.py | 33 ++------- tests/test_decomp/test_bidiag.py | 8 +-- tests/test_decomp/test_tridiag_sym.py | 18 ++--- tests/test_decomp/test_tridiag_sym_adjoint.py | 2 +- tests/test_funm/test_funm_lanczos_sym.py | 5 +- 7 files changed, 74 insertions(+), 65 deletions(-) diff --git a/matfree/decomp.py b/matfree/decomp.py index 9f7249d..b196a71 100644 --- a/matfree/decomp.py +++ b/matfree/decomp.py @@ -12,7 +12,14 @@ from matfree.backend.typing import Array, Callable -def tridiag_sym(krylov_depth, /, *, reortho: str = "full", custom_vjp: bool = True): +def tridiag_sym( + krylov_depth, + /, + *, + materialize: bool = True, + reortho: str = "full", + custom_vjp: bool = True, +): r"""Construct an implementation of **tridiagonalisation**. Uses pre-allocation, and full reorthogonalisation if `reortho` is set to `"full"`. @@ -44,15 +51,19 @@ def tridiag_sym(krylov_depth, /, *, reortho: str = "full", custom_vjp: bool = Tr """ if reortho == "full": - return _tridiag_reortho_full(krylov_depth, custom_vjp=custom_vjp) + return _tridiag_reortho_full( + krylov_depth, custom_vjp=custom_vjp, materialize=materialize + ) if reortho == "none": - return _tridiag_reortho_none(krylov_depth, custom_vjp=custom_vjp) + return _tridiag_reortho_none( + krylov_depth, custom_vjp=custom_vjp, materialize=materialize + ) msg = f"reortho={reortho} unsupported. Choose eiter {'full', 'none'}." raise ValueError(msg) -def _tridiag_reortho_full(krylov_depth, /, *, custom_vjp): +def _tridiag_reortho_full(krylov_depth: int, /, *, custom_vjp: bool, materialize: bool): # Implement via Arnoldi to use the reorthogonalised adjoints. # The complexity difference is minimal with full reortho. alg = hessenberg(krylov_depth, custom_vjp=custom_vjp, reortho="full") @@ -60,27 +71,50 @@ def _tridiag_reortho_full(krylov_depth, /, *, custom_vjp): def estimate(matvec, vec, *params): Q, H, v, _norm = alg(matvec, vec, *params) + remainder = (v / linalg.vector_norm(v), linalg.vector_norm(v)) + T = 0.5 * (H + H.T) diags = linalg.diagonal(T, offset=0) offdiags = linalg.diagonal(T, offset=1) + + if materialize: + matrix = _todense_tridiag_sym(diags, offdiags) + decomposition = (Q.T, matrix) + return decomposition, remainder + decomposition = (Q.T, (diags, offdiags)) - remainder = (v / linalg.vector_norm(v), linalg.vector_norm(v)) return decomposition, remainder return estimate -def _tridiag_reortho_none(krylov_depth, /, *, custom_vjp): +def _todense_tridiag_sym(diag, off_diag): + diag = linalg.diagonal_matrix(diag) + offdiag1 = linalg.diagonal_matrix(off_diag, -1) + offdiag2 = linalg.diagonal_matrix(off_diag, 1) + return diag + offdiag1 + offdiag2 + + +def _tridiag_reortho_none(krylov_depth: int, /, *, custom_vjp: bool, materialize: bool): def estimate(matvec, vec, *params): + (Q, H), v = _estimate(matvec, vec, *params) + + if materialize: + H = _todense_tridiag_sym(*H) + return (Q, H), v + + def _estimate(matvec, vec, *params): *values, _ = _tridiag_forward(matvec, krylov_depth, vec, *params) return values def estimate_fwd(matvec, vec, *params): - value = estimate(matvec, vec, *params) + value = _estimate(matvec, vec, *params) return value, (value, (linalg.vector_norm(vec), *params)) def estimate_bwd(matvec, cache, vjp_incoming): # Read incoming gradients and stack related quantities + print(tree_util.tree_map(np.shape, vjp_incoming)) + (dxs, (dalphas, dbetas)), (dx_last, dbeta_last) = vjp_incoming dxs = np.concatenate((dxs, dx_last[None])) dbetas = np.concatenate((dbetas, dbeta_last[None])) @@ -105,8 +139,8 @@ def estimate_bwd(matvec, cache, vjp_incoming): return grads if custom_vjp: - estimate = func.custom_vjp(estimate, nondiff_argnums=[0]) - estimate.defvjp(estimate_fwd, estimate_bwd) # type: ignore + _estimate = func.custom_vjp(_estimate, nondiff_argnums=[0]) + _estimate.defvjp(estimate_fwd, estimate_bwd) # type: ignore return estimate @@ -483,7 +517,7 @@ def _extract_diag(x, offset=0): return linalg.diagonal_matrix(diag, offset=offset) -def bidiag(depth: int, /, matrix_shape): +def bidiag(depth: int, /, matrix_shape, materialize: bool = True): """Construct an implementation of **bidiagonalisation**. Uses pre-allocation and full reorthogonalisation. @@ -509,10 +543,6 @@ def bidiag(depth: int, /, matrix_shape): msg3 = f"for a matrix with shape {matrix_shape}." raise ValueError(msg1 + msg2 + msg3) - # todo: move the matvecs to the estimate() functions - # of tridiag and hessenberg. Then, update the SLQ functions - # then, give all methods here a materialise=True argument - # and simplify SLQ code massively. def estimate(Av: Callable, vA: Callable, v0, *parameters): v0_norm, length = _normalise(v0) init_val = init(v0_norm) @@ -561,6 +591,11 @@ def step(Av, vA, state: State, *parameters) -> State: def extract(state: State, /): _, uk_all, vk_all, alphas, betas, beta, vk = state + + if materialize: + B = _todense_bidiag(alphas, betas[1:]) + return uk_all.T, B, vk_all, (beta, vk) + return uk_all.T, (alphas, betas[1:]), vk_all, (beta, vk) def _gram_schmidt_classical(vec, vectors): # Gram-Schmidt @@ -577,4 +612,9 @@ def _normalise(vec): length = linalg.vector_norm(vec) return vec / length, length + def _todense_bidiag(d, e): + diag = linalg.diagonal_matrix(d) + offdiag = linalg.diagonal_matrix(e, 1) + return diag + offdiag + return estimate diff --git a/matfree/eig.py b/matfree/eig.py index 35a49c0..69e392d 100644 --- a/matfree/eig.py +++ b/matfree/eig.py @@ -30,11 +30,10 @@ def svd_partial( Shape of the matrix involved in matrix-vector and vector-matrix products. """ # Factorise the matrix - algorithm = decomp.bidiag(depth, matrix_shape=matrix_shape) - u, (d, e), vt, *_ = algorithm(Av, vA, v0) + algorithm = decomp.bidiag(depth, matrix_shape=matrix_shape, materialize=True) + u, B, vt, *_ = algorithm(Av, vA, v0) # Compute SVD of factorisation - B = _bidiagonal_dense(d, e) U, S, Vt = linalg.svd(B, full_matrices=False) # Combine orthogonal transformations diff --git a/matfree/funm.py b/matfree/funm.py index 9dfd516..3b682b7 100644 --- a/matfree/funm.py +++ b/matfree/funm.py @@ -136,8 +136,8 @@ def funm_lanczos_sym(dense_funm: Callable, tridiag_sym: Callable, /) -> Callable def estimate(matvec: Callable, vec, *parameters): length = linalg.vector_norm(vec) vec /= length - (basis, (diag, off_diag)), _ = tridiag_sym(matvec, vec, *parameters) - matrix = _todense_tridiag_sym(diag, off_diag) + (basis, matrix), _ = tridiag_sym(matvec, vec, *parameters) + # matrix = _todense_tridiag_sym(diag, off_diag) funm = dense_funm(matrix) e1 = np.eye(len(matrix))[0, :] @@ -161,7 +161,7 @@ def integrand_funm_sym(matfun, order, /): """ # Todo: expect these to be passed by the user. dense_funm = dense_funm_sym_eigh(matfun) - algorithm = decomp.tridiag_sym(order) + algorithm = decomp.tridiag_sym(order, materialize=True) def quadform(matvec, v0, *parameters): v0_flat, v_unflatten = tree_util.ravel_pytree(v0) @@ -174,9 +174,8 @@ def matvec_flat(v_flat, *p): flat, unflatten = tree_util.ravel_pytree(Av) return flat - (_, (diag, off_diag)), _ = algorithm(matvec_flat, v0_flat, *parameters) + (_, dense), _ = algorithm(matvec_flat, v0_flat, *parameters) - dense = _todense_tridiag_sym(diag, off_diag) fA = dense_funm(dense) e1 = np.eye(len(fA))[0, :] return length**2 * linalg.inner(e1, fA @ e1) @@ -224,6 +223,7 @@ def matvec_flat(v_flat, *p): w0_flat, w_unflatten = func.eval_shape(matvec_flat, v0_flat) matrix_shape = (*np.shape(w0_flat), *np.shape(v0_flat)) + algorithm = decomp.bidiag(depth, matrix_shape=matrix_shape, materialize=True) def vecmat_flat(w_flat): w = w_unflatten(w_flat) @@ -231,14 +231,11 @@ def vecmat_flat(w_flat): return tree_util.ravel_pytree(wA)[0] # Decompose into orthogonal-bidiag-orthogonal - algorithm = decomp.bidiag(depth, matrix_shape=matrix_shape) matvec_flat_p = lambda v: matvec_flat(v)[0] # noqa: E731 output = algorithm(matvec_flat_p, vecmat_flat, v0_flat, *parameters) - u, (d, e), vt, *_ = output + u, B, vt, *_ = output # Compute SVD of factorisation - B = _todense_bidiag(d, e) - # todo: turn the following lines into dense_funm_svd() _, S, Vt = linalg.svd(B, full_matrices=False) @@ -277,21 +274,3 @@ def fun(dense_matrix): return linalg.funm_schur(dense_matrix, matfun) return fun - - -# todo: if we move this logic to the decomposition algorithms -# (e.g. with a materalize=True flag in the decomposition construction), -# then all trace_of_funm implementation reduce to very few lines. - - -def _todense_tridiag_sym(diag, off_diag): - diag = linalg.diagonal_matrix(diag) - offdiag1 = linalg.diagonal_matrix(off_diag, -1) - offdiag2 = linalg.diagonal_matrix(off_diag, 1) - return diag + offdiag1 + offdiag2 - - -def _todense_bidiag(d, e): - diag = linalg.diagonal_matrix(d) - offdiag = linalg.diagonal_matrix(e, 1) - return diag + offdiag diff --git a/tests/test_decomp/test_bidiag.py b/tests/test_decomp/test_bidiag.py index 805c7db..f7edbed 100644 --- a/tests/test_decomp/test_bidiag.py +++ b/tests/test_decomp/test_bidiag.py @@ -30,7 +30,7 @@ def Av(v): def vA(v): return v @ A - algorithm = decomp.bidiag(order, matrix_shape=np.shape(A)) + algorithm = decomp.bidiag(order, matrix_shape=np.shape(A), materialize=False) Us, Bs, Vs, (b, v), ln = algorithm(Av, vA, v0) (d_m, e_m) = Bs @@ -70,7 +70,7 @@ def test_error_too_high_depth(A): max_depth = min(nrows, ncols) - 1 with testing.raises(ValueError, match=""): - _ = decomp.bidiag(max_depth + 1, matrix_shape=np.shape(A)) + _ = decomp.bidiag(max_depth + 1, matrix_shape=np.shape(A), materialize=False) @testing.parametrize("nrows", [5]) @@ -80,7 +80,7 @@ def test_error_too_low_depth(A): """Assert that a ValueError is raised when the depth is negative.""" min_depth = 0 with testing.raises(ValueError, match=""): - _ = decomp.bidiag(min_depth - 1, matrix_shape=np.shape(A)) + _ = decomp.bidiag(min_depth - 1, matrix_shape=np.shape(A), materialize=False) @testing.parametrize("nrows", [15]) @@ -98,7 +98,7 @@ def Av(v): def vA(v): return v @ A - algorithm = decomp.bidiag(0, matrix_shape=np.shape(A)) + algorithm = decomp.bidiag(0, matrix_shape=np.shape(A), materialize=False) Us, Bs, Vs, (b, v), ln = algorithm(Av, vA, v0) (d_m, e_m) = Bs assert np.shape(Us) == (nrows, 1) diff --git a/tests/test_decomp/test_tridiag_sym.py b/tests/test_decomp/test_tridiag_sym.py index b925aa4..e2baeef 100644 --- a/tests/test_decomp/test_tridiag_sym.py +++ b/tests/test_decomp/test_tridiag_sym.py @@ -13,11 +13,10 @@ def test_full_rank_reconstruction_is_exact(reortho, ndim): vector = np.flip(np.arange(1.0, 1.0 + len(eigvals))) # Run Lanczos approximation - algorithm = decomp.tridiag_sym(ndim, reortho=reortho) - (lanczos_vectors, tridiag), _ = algorithm(lambda s, p: p @ s, vector, matrix) + algorithm = decomp.tridiag_sym(ndim, reortho=reortho, materialize=True) + (lanczos_vectors, dense_matrix), _ = algorithm(lambda s, p: p @ s, vector, matrix) # Reconstruct the original matrix from the full-order approximation - dense_matrix = _dense_tridiag_sym(*tridiag) matrix_reconstructed = lanczos_vectors.T @ dense_matrix @ lanczos_vectors if reortho == "full": @@ -46,19 +45,10 @@ def test_mid_rank_reconstruction_satisfies_decomposition(ndim, krylov_depth, reo vector = np.flip(np.arange(1.0, 1.0 + len(eigvals))) # Run Lanczos approximation - algorithm = decomp.tridiag_sym(krylov_depth, reortho=reortho) - (lanczos_vectors, tridiag), (q, b) = algorithm(lambda s, p: p @ s, vector, matrix) + algorithm = decomp.tridiag_sym(krylov_depth, reortho=reortho, materialize=True) + (Q, T), (q, b) = algorithm(lambda s, p: p @ s, vector, matrix) # Verify the decomposition - Q, T = lanczos_vectors, _dense_tridiag_sym(*tridiag) tols = {"atol": 1e-5, "rtol": 1e-5} e_K = np.eye(krylov_depth)[-1] assert np.allclose(matrix @ Q.T, Q.T @ T + linalg.outer(e_K, q * b).T, **tols) - - -def _dense_tridiag_sym(diagonal, off_diagonal): - return ( - linalg.diagonal_matrix(diagonal) - + linalg.diagonal_matrix(off_diagonal, 1) - + linalg.diagonal_matrix(off_diagonal, -1) - ) diff --git a/tests/test_decomp/test_tridiag_sym_adjoint.py b/tests/test_decomp/test_tridiag_sym_adjoint.py index f464c55..997bea2 100644 --- a/tests/test_decomp/test_tridiag_sym_adjoint.py +++ b/tests/test_decomp/test_tridiag_sym_adjoint.py @@ -23,7 +23,7 @@ def matvec(s, p): # Construct a vector-to-vector decomposition function def decompose(f, *, custom_vjp): - kwargs = {"reortho": reortho, "custom_vjp": custom_vjp} + kwargs = {"reortho": reortho, "custom_vjp": custom_vjp, "materialize": False} algorithm = decomp.tridiag_sym(krylov_order, **kwargs) output = algorithm(matvec, *unflatten(f)) return tree_util.ravel_pytree(output)[0] diff --git a/tests/test_funm/test_funm_lanczos_sym.py b/tests/test_funm/test_funm_lanczos_sym.py index 45052cf..024704a 100644 --- a/tests/test_funm/test_funm_lanczos_sym.py +++ b/tests/test_funm/test_funm_lanczos_sym.py @@ -5,7 +5,8 @@ @testing.parametrize("dense_funm", [funm.dense_funm_sym_eigh, funm.dense_funm_schur]) -def test_funm_lanczos_sym_matches_eigh_implementation(dense_funm, n=11): +@testing.parametrize("reortho", ["full", "none"]) +def test_funm_lanczos_sym_matches_eigh_implementation(dense_funm, reortho, n=11): """Test matrix-function-vector products via Lanczos' algorithm.""" # Create a test-problem: matvec, matrix function, # vector, and parameters (a matrix). @@ -28,7 +29,7 @@ def fun(x): # Compute the matrix-function vector product dense_funm = dense_funm(fun) - lanczos = decomp.tridiag_sym(6) + lanczos = decomp.tridiag_sym(6, materialize=True, reortho=reortho) matfun_vec = funm.funm_lanczos_sym(dense_funm, lanczos) received = matfun_vec(matvec, v, matrix) assert np.allclose(expected, received, atol=1e-6)