Skip to content

Commit

Permalink
Implement funm_arnoldi to enable functions of non-structured matrices (
Browse files Browse the repository at this point in the history
…#206)

* Implement a matrix-function that uses the Arnoldi iteration

* Remove jnp.set_printoptions from doctests because it caused problems

* Clarify the test-framework for funm_arnoldi to enable future tests with e.g. log
  • Loading branch information
pnkraemer authored Aug 29, 2024
1 parent 38f3f8b commit 085624d
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 0 deletions.
4 changes: 4 additions & 0 deletions matfree/backend/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,7 @@ def cg(Av, b, /):

def funm_schur(A, f, /):
return jax.scipy.linalg.funm(A, f)


def funm_pade_exp(A, /):
return jax.scipy.linalg.expm(A)
46 changes: 46 additions & 0 deletions matfree/funm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
>>> fAx = matfun_vec(lambda s: A @ s, v)
>>> print(fAx.shape)
(10,)
"""

from matfree import decomp
Expand Down Expand Up @@ -145,6 +146,37 @@ def estimate(matvec: Callable, vec, *parameters):
return estimate


def funm_arnoldi(dense_funm: Callable, hessenberg: Callable, /) -> Callable:
"""Implement a matrix-function-vector product via the Arnoldi iteration.
This algorithm uses the Arnoldi iteration
and therefore applies only to all square matrices.
Parameters
----------
dense_funm
An implementation of a function of a dense matrix.
For example, the output of
[funm.dense_funm_sym_eigh][matfree.funm.dense_funm_sym_eigh]
[funm.dense_funm_schur][matfree.funm.dense_funm_schur]
hessenberg
An implementation of Hessenberg-factorisation.
E.g., the output of
[decomp.hessenberg][matfree.decomp.hessenberg].
"""

def estimate(matvec: Callable, vec, *parameters):
length = linalg.vector_norm(vec)
vec /= length
basis, matrix, *_ = hessenberg(matvec, vec, *parameters)

funm = dense_funm(matrix)
e1 = np.eye(len(matrix))[0, :]
return length * (basis @ funm @ e1)

return estimate


def integrand_funm_sym_logdet(order, /):
"""Construct the integrand for the log-determinant.
Expand Down Expand Up @@ -247,6 +279,7 @@ def vecmat_flat(w_flat):
return quadform


# todo: rename to *_eigh_sym
def dense_funm_sym_eigh(matfun):
"""Implement dense matrix-functions via symmetric eigendecompositions.
Expand All @@ -273,3 +306,16 @@ def fun(dense_matrix):
return linalg.funm_schur(dense_matrix, matfun)

return fun


def dense_funm_pade_exp():
"""Implement dense matrix-exponentials using a Pade approximation.
Use it to construct one of the matrix-free matrix-function implementations,
e.g. [matfree.funm.funm_arnoldi][matfree.funm.funm_arnoldi].
"""

def fun(dense_matrix):
return linalg.funm_pade_exp(dense_matrix)

return fun
29 changes: 29 additions & 0 deletions tests/test_funm/test_funm_arnoldi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Test matrix-function-vector products via the Arnoldi iteration."""

from matfree import decomp, funm
from matfree.backend import np, prng, testing


def case_expm():
return funm.dense_funm_pade_exp()


@testing.parametrize("reortho", ["full", "none"])
@testing.parametrize_with_cases("dense_funm", cases=".", prefix="case_")
def test_funm_arnoldi_matches_schur_implementation(dense_funm, reortho, n=11):
"""Test matrix-function-vector products via the Arnoldi iteration."""
# Create a test-problem: matvec, matrix function,
# vector, and parameters (a matrix).

matrix = prng.normal(prng.prng_key(1), shape=(n, n))
v = prng.normal(prng.prng_key(2), shape=(n,))

# Compute the solution
expected = dense_funm(matrix) @ v

# Compute the matrix-function vector product
arnoldi = decomp.hessenberg((n * 3) // 4, reortho=reortho)
matfun_vec = funm.funm_arnoldi(dense_funm, arnoldi)
received = matfun_vec(lambda s, p: p @ s, v, matrix)

assert np.allclose(expected, received, rtol=1e-1, atol=1e-1)

0 comments on commit 085624d

Please sign in to comment.