From 08b694f9d04314c4660e036cc2f374f68af46252 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Thu, 29 Aug 2024 08:45:39 +0200 Subject: [PATCH 1/3] Implement a matrix-function that uses the Arnoldi iteration --- matfree/backend/linalg.py | 4 +++ matfree/funm.py | 44 ++++++++++++++++++++++++++++ tests/test_funm/test_funm_arnoldi.py | 26 ++++++++++++++++ 3 files changed, 74 insertions(+) create mode 100644 tests/test_funm/test_funm_arnoldi.py diff --git a/matfree/backend/linalg.py b/matfree/backend/linalg.py index 8ca87ec..c13359d 100644 --- a/matfree/backend/linalg.py +++ b/matfree/backend/linalg.py @@ -81,3 +81,7 @@ def cg(Av, b, /): def funm_schur(A, f, /): return jax.scipy.linalg.funm(A, f) + + +def funm_exp_pade(A, /): + return jax.scipy.linalg.expm(A) diff --git a/matfree/funm.py b/matfree/funm.py index 3b682b7..b3f0839 100644 --- a/matfree/funm.py +++ b/matfree/funm.py @@ -146,6 +146,37 @@ def estimate(matvec: Callable, vec, *parameters): return estimate +def funm_arnoldi(dense_funm: Callable, hessenberg: Callable, /) -> Callable: + """Implement a matrix-function-vector product via the Arnoldi iteration. + + This algorithm uses the Arnoldi iteration + and therefore applies only to all square matrices. + + Parameters + ---------- + dense_funm + An implementation of a function of a dense matrix. + For example, the output of + [funm.dense_funm_sym_eigh][matfree.funm.dense_funm_sym_eigh] + [funm.dense_funm_schur][matfree.funm.dense_funm_schur] + hessenberg + An implementation of Hessenberg-factorisation. + E.g., the output of + [decomp.hessenberg][matfree.decomp.hessenberg]. + """ + + def estimate(matvec: Callable, vec, *parameters): + length = linalg.vector_norm(vec) + vec /= length + basis, matrix, *_ = hessenberg(matvec, vec, *parameters) + + funm = dense_funm(matrix) + e1 = np.eye(len(matrix))[0, :] + return length * (basis @ funm @ e1) + + return estimate + + def integrand_funm_sym_logdet(order, /): """Construct the integrand for the log-determinant. @@ -274,3 +305,16 @@ def fun(dense_matrix): return linalg.funm_schur(dense_matrix, matfun) return fun + + +def dense_funm_exp_pade(): + """Implement dense matrix-exponentials using a Pade approximation. + + Use it to construct one of the matrix-free matrix-function implementations, + e.g. [matfree.funm.funm_arnoldi][matfree.funm.funm_arnoldi]. + """ + + def fun(dense_matrix): + return linalg.funm_exp_pade(dense_matrix) + + return fun diff --git a/tests/test_funm/test_funm_arnoldi.py b/tests/test_funm/test_funm_arnoldi.py new file mode 100644 index 0000000..5bb50b6 --- /dev/null +++ b/tests/test_funm/test_funm_arnoldi.py @@ -0,0 +1,26 @@ +"""Test matrix-function-vector products via the Arnoldi iteration.""" + +from matfree import decomp, funm +from matfree.backend import np, prng, testing + + +@testing.parametrize("reortho", ["full", "none"]) +def test_funm_arnoldi_matches_schur_implementation(reortho, n=11): + """Test matrix-function-vector products via the Arnoldi iteration.""" + # Create a test-problem: matvec, matrix function, + # vector, and parameters (a matrix). + + matrix = prng.normal(prng.prng_key(1), shape=(n, n)) + v = prng.normal(prng.prng_key(2), shape=(n,)) + + # Compute the solution + dense_funm = funm.dense_funm_exp_pade() + expected = dense_funm(matrix) @ v + + # Compute the matrix-function vector product + # We use + arnoldi = decomp.hessenberg((n * 3) // 4, reortho=reortho) + matfun_vec = funm.funm_arnoldi(dense_funm, arnoldi) + received = matfun_vec(lambda s, p: p @ s, v, matrix) + + assert np.allclose(expected, received, rtol=1e-1, atol=1e-1) From a44bf31a638b245d9def3baccc1694462dee711d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Thu, 29 Aug 2024 09:00:44 +0200 Subject: [PATCH 2/3] Remove jnp.set_printoptions from doctests because it caused problems --- README.md | 2 -- matfree/funm.py | 8 ++++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 2a0a32b..95aaebc 100644 --- a/README.md +++ b/README.md @@ -49,8 +49,6 @@ Import matfree and JAX, and set up a test problem. >>> import jax.numpy as jnp >>> from matfree import stochtrace >>> ->>> jnp.set_printoptions(1) - >>> A = jnp.reshape(jnp.arange(12.0), (6, 2)) >>> >>> def matvec(x): diff --git a/matfree/funm.py b/matfree/funm.py index b3f0839..2f6499a 100644 --- a/matfree/funm.py +++ b/matfree/funm.py @@ -23,8 +23,6 @@ >>> import jax.numpy as jnp >>> from matfree import decomp >>> ->>> jnp.set_printoptions(1) ->>> >>> M = jax.random.normal(jax.random.PRNGKey(1), shape=(10, 10)) >>> A = M.T @ M >>> v = jax.random.normal(jax.random.PRNGKey(2), shape=(10,)) @@ -33,8 +31,10 @@ >>> matfun = dense_funm_sym_eigh(jnp.log) >>> tridiag = decomp.tridiag_sym(4) >>> matfun_vec = funm_lanczos_sym(matfun, tridiag) ->>> matfun_vec(lambda s: A @ s, v) -Array([-4.1, -1.3, -2.2, -2.1, -1.2, -3.3, -0.2, 0.3, 0.7, 0.9], dtype=float32) +>>> fAx = matfun_vec(lambda s: A @ s, v) +>>> print(fAx.shape) +(10,) + """ from matfree import decomp From ffbea9baa63b55e5bffa58d108535ea68521eb68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Thu, 29 Aug 2024 09:12:22 +0200 Subject: [PATCH 3/3] Clarify the test-framework for funm_arnoldi to enable future tests with e.g. log --- matfree/backend/linalg.py | 2 +- matfree/funm.py | 5 +++-- tests/test_funm/test_funm_arnoldi.py | 9 ++++++--- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/matfree/backend/linalg.py b/matfree/backend/linalg.py index c13359d..b5f1fc7 100644 --- a/matfree/backend/linalg.py +++ b/matfree/backend/linalg.py @@ -83,5 +83,5 @@ def funm_schur(A, f, /): return jax.scipy.linalg.funm(A, f) -def funm_exp_pade(A, /): +def funm_pade_exp(A, /): return jax.scipy.linalg.expm(A) diff --git a/matfree/funm.py b/matfree/funm.py index 2f6499a..4137a29 100644 --- a/matfree/funm.py +++ b/matfree/funm.py @@ -279,6 +279,7 @@ def vecmat_flat(w_flat): return quadform +# todo: rename to *_eigh_sym def dense_funm_sym_eigh(matfun): """Implement dense matrix-functions via symmetric eigendecompositions. @@ -307,7 +308,7 @@ def fun(dense_matrix): return fun -def dense_funm_exp_pade(): +def dense_funm_pade_exp(): """Implement dense matrix-exponentials using a Pade approximation. Use it to construct one of the matrix-free matrix-function implementations, @@ -315,6 +316,6 @@ def dense_funm_exp_pade(): """ def fun(dense_matrix): - return linalg.funm_exp_pade(dense_matrix) + return linalg.funm_pade_exp(dense_matrix) return fun diff --git a/tests/test_funm/test_funm_arnoldi.py b/tests/test_funm/test_funm_arnoldi.py index 5bb50b6..c122e1d 100644 --- a/tests/test_funm/test_funm_arnoldi.py +++ b/tests/test_funm/test_funm_arnoldi.py @@ -4,8 +4,13 @@ from matfree.backend import np, prng, testing +def case_expm(): + return funm.dense_funm_pade_exp() + + @testing.parametrize("reortho", ["full", "none"]) -def test_funm_arnoldi_matches_schur_implementation(reortho, n=11): +@testing.parametrize_with_cases("dense_funm", cases=".", prefix="case_") +def test_funm_arnoldi_matches_schur_implementation(dense_funm, reortho, n=11): """Test matrix-function-vector products via the Arnoldi iteration.""" # Create a test-problem: matvec, matrix function, # vector, and parameters (a matrix). @@ -14,11 +19,9 @@ def test_funm_arnoldi_matches_schur_implementation(reortho, n=11): v = prng.normal(prng.prng_key(2), shape=(n,)) # Compute the solution - dense_funm = funm.dense_funm_exp_pade() expected = dense_funm(matrix) @ v # Compute the matrix-function vector product - # We use arnoldi = decomp.hessenberg((n * 3) // 4, reortho=reortho) matfun_vec = funm.funm_arnoldi(dense_funm, arnoldi) received = matfun_vec(lambda s, p: p @ s, v, matrix)