Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement funm_arnoldi to enable functions of non-structured matrices #206

Merged
merged 4 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions matfree/backend/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,7 @@ def cg(Av, b, /):

def funm_schur(A, f, /):
return jax.scipy.linalg.funm(A, f)


def funm_pade_exp(A, /):
return jax.scipy.linalg.expm(A)
46 changes: 46 additions & 0 deletions matfree/funm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
>>> fAx = matfun_vec(lambda s: A @ s, v)
>>> print(fAx.shape)
(10,)

"""

from matfree import decomp
Expand Down Expand Up @@ -145,6 +146,37 @@ def estimate(matvec: Callable, vec, *parameters):
return estimate


def funm_arnoldi(dense_funm: Callable, hessenberg: Callable, /) -> Callable:
"""Implement a matrix-function-vector product via the Arnoldi iteration.

This algorithm uses the Arnoldi iteration
and therefore applies only to all square matrices.

Parameters
----------
dense_funm
An implementation of a function of a dense matrix.
For example, the output of
[funm.dense_funm_sym_eigh][matfree.funm.dense_funm_sym_eigh]
[funm.dense_funm_schur][matfree.funm.dense_funm_schur]
hessenberg
An implementation of Hessenberg-factorisation.
E.g., the output of
[decomp.hessenberg][matfree.decomp.hessenberg].
"""

def estimate(matvec: Callable, vec, *parameters):
length = linalg.vector_norm(vec)
vec /= length
basis, matrix, *_ = hessenberg(matvec, vec, *parameters)

funm = dense_funm(matrix)
e1 = np.eye(len(matrix))[0, :]
return length * (basis @ funm @ e1)

return estimate


def integrand_funm_sym_logdet(order, /):
"""Construct the integrand for the log-determinant.

Expand Down Expand Up @@ -247,6 +279,7 @@ def vecmat_flat(w_flat):
return quadform


# todo: rename to *_eigh_sym
def dense_funm_sym_eigh(matfun):
"""Implement dense matrix-functions via symmetric eigendecompositions.

Expand All @@ -273,3 +306,16 @@ def fun(dense_matrix):
return linalg.funm_schur(dense_matrix, matfun)

return fun


def dense_funm_pade_exp():
"""Implement dense matrix-exponentials using a Pade approximation.

Use it to construct one of the matrix-free matrix-function implementations,
e.g. [matfree.funm.funm_arnoldi][matfree.funm.funm_arnoldi].
"""

def fun(dense_matrix):
return linalg.funm_pade_exp(dense_matrix)

return fun
29 changes: 29 additions & 0 deletions tests/test_funm/test_funm_arnoldi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Test matrix-function-vector products via the Arnoldi iteration."""

from matfree import decomp, funm
from matfree.backend import np, prng, testing


def case_expm():
return funm.dense_funm_pade_exp()


@testing.parametrize("reortho", ["full", "none"])
@testing.parametrize_with_cases("dense_funm", cases=".", prefix="case_")
def test_funm_arnoldi_matches_schur_implementation(dense_funm, reortho, n=11):
"""Test matrix-function-vector products via the Arnoldi iteration."""
# Create a test-problem: matvec, matrix function,
# vector, and parameters (a matrix).

matrix = prng.normal(prng.prng_key(1), shape=(n, n))
v = prng.normal(prng.prng_key(2), shape=(n,))

# Compute the solution
expected = dense_funm(matrix) @ v

# Compute the matrix-function vector product
arnoldi = decomp.hessenberg((n * 3) // 4, reortho=reortho)
matfun_vec = funm.funm_arnoldi(dense_funm, arnoldi)
received = matfun_vec(lambda s, p: p @ s, v, matrix)

assert np.allclose(expected, received, rtol=1e-1, atol=1e-1)
Loading