From 5c1133deecd00a45b7890c2e6999ecab04ac9762 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Wed, 10 Jan 2024 13:19:53 +0100 Subject: [PATCH 1/2] Make integrand_slq_spd behave correctly for arbitrary input vectors, not just unit-second-moment input vectors --- matfree/slq.py | 7 +++---- tests/test_slq/test_logdet_spd.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/matfree/slq.py b/matfree/slq.py index 0b262dc..511bea0 100644 --- a/matfree/slq.py +++ b/matfree/slq.py @@ -25,7 +25,8 @@ def integrand_slq_spd(matfun, order, matvec, /): def quadform(v0, *parameters): v0_flat, v_unflatten = tree_util.ravel_pytree(v0) - v0_flat /= linalg.vector_norm(v0_flat) + length = linalg.vector_norm(v0_flat) + v0_flat /= length def matvec_flat(v_flat): v = v_unflatten(v_flat) @@ -50,10 +51,8 @@ def matvec_flat(v_flat): # Since Q orthogonal (orthonormal) to v0, Q v = Q[0], # and therefore (Q v)^T f(D) (Qv) = Q[0] * f(diag) * Q[0] - (dim,) = v0_flat.shape - fx_eigvals = func.vmap(matfun)(eigvals) - return dim * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :]) + return length**2 * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :]) return quadform diff --git a/tests/test_slq/test_logdet_spd.py b/tests/test_slq/test_logdet_spd.py index fc43b5e..40b08c3 100644 --- a/tests/test_slq/test_logdet_spd.py +++ b/tests/test_slq/test_logdet_spd.py @@ -36,3 +36,31 @@ def matvec(x): expected = linalg.slogdet(A)[1] print_if_assert_fails = ("error", np.abs(received - expected), "target:", expected) assert np.allclose(received, expected, atol=1e-2, rtol=1e-2), print_if_assert_fails + + +@testing.parametrize("n", [200]) +# usually: ~1.5 * num_significant_eigvals. +# But logdet seems to converge sooo much faster. +def test_logdet_spd_exact_for_full_order_lanczos(n): + r"""Computing v^\top f(A) v with max-order Lanczos should be exact for _any_ v.""" + # Construct a (numerically nice) matrix + eigvals = np.arange(1.0, 1.0 + n, step=1.0) + A = test_util.symmetric_matrix_from_eigenvalues(eigvals) + + # Set up max-order Lanczos approximation inside SLQ for the matrix-logarithm + order = n - 1 + integrand = slq.integrand_logdet_spd(order, lambda v: A @ v) + + # Construct a vector without that does not have expected 2-norm equal to "dim" + x = prng.normal(prng.prng_key(seed=1), shape=(n,)) + 20 + + # Compute v^\top @ log(A) @ v via Lanczos + received = integrand(x) + + # Compute the "true" value of v^\top @ log(A) @ v via eigenvalues + eigvals, eigvecs = linalg.eigh(A) + logA = eigvecs @ linalg.diagonal_matrix(np.log(eigvals)) @ eigvecs.T + expected = x.T @ logA @ x + + # They should be identical + assert np.allclose(received, expected) From 7ce145533be650314d796721cff677b9787ef1af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Wed, 10 Jan 2024 13:30:34 +0100 Subject: [PATCH 2/2] Make integrand_slq_product behave correctly, too --- matfree/slq.py | 6 +++--- tests/test_slq/test_logdet_product.py | 30 +++++++++++++++++++++++++++ tests/test_slq/test_logdet_spd.py | 4 ++-- 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/matfree/slq.py b/matfree/slq.py index 511bea0..3cfac3f 100644 --- a/matfree/slq.py +++ b/matfree/slq.py @@ -85,7 +85,8 @@ def integrand_slq_product(matfun, depth, matvec, vecmat, /): def quadform(v0, *parameters): v0_flat, v_unflatten = tree_util.ravel_pytree(v0) - v0_flat /= linalg.vector_norm(v0_flat) + length = linalg.vector_norm(v0_flat) + v0_flat /= length def matvec_flat(v_flat): v = v_unflatten(v_flat) @@ -114,10 +115,9 @@ def vecmat_flat(w_flat): # Since Q orthogonal (orthonormal) to v0, Q v = Q[0], # and therefore (Q v)^T f(D) (Qv) = Q[0] * f(diag) * Q[0] - _, ncols = matrix_shape eigvals, eigvecs = S**2, Vt.T fx_eigvals = func.vmap(matfun)(eigvals) - return ncols * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :]) + return length**2 * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :]) return quadform diff --git a/tests/test_slq/test_logdet_product.py b/tests/test_slq/test_logdet_product.py index 8b765d3..2874496 100644 --- a/tests/test_slq/test_logdet_product.py +++ b/tests/test_slq/test_logdet_product.py @@ -38,3 +38,33 @@ def vecmat(x): expected = linalg.slogdet(A.T @ A)[1] print_if_assert_fails = ("error", np.abs(received - expected), "target:", expected) assert np.allclose(received, expected, atol=1e-2, rtol=1e-2), print_if_assert_fails + + +@testing.parametrize("n", [50]) +# usually: ~1.5 * num_significant_eigvals. +# But logdet seems to converge sooo much faster. +def test_logdet_product_exact_for_full_order_lanczos(n): + r"""Computing v^\top f(A^\top @ A) v with max-order Lanczos is exact for _any_ v.""" + # Construct a (numerically nice) matrix + singular_values = np.sqrt(np.arange(1.0, 1.0 + n, step=1.0)) + A = test_util.asymmetric_matrix_from_singular_values( + singular_values, nrows=n, ncols=n + ) + + # Set up max-order Lanczos approximation inside SLQ for the matrix-logarithm + order = n - 1 + integrand = slq.integrand_logdet_product(order, lambda v: A @ v, lambda v: v @ A) + + # Construct a vector without that does not have expected 2-norm equal to "dim" + x = prng.normal(prng.prng_key(seed=1), shape=(n,)) + 1 + + # Compute v^\top @ log(A) @ v via Lanczos + received = integrand(x) + + # Compute the "true" value of v^\top @ log(A) @ v via eigenvalues + eigvals, eigvecs = linalg.eigh(A.T @ A) + logA = eigvecs @ linalg.diagonal_matrix(np.log(eigvals)) @ eigvecs.T + expected = x.T @ logA @ x + + # They should be identical + assert np.allclose(received, expected) diff --git a/tests/test_slq/test_logdet_spd.py b/tests/test_slq/test_logdet_spd.py index 40b08c3..f03b2c1 100644 --- a/tests/test_slq/test_logdet_spd.py +++ b/tests/test_slq/test_logdet_spd.py @@ -38,7 +38,7 @@ def matvec(x): assert np.allclose(received, expected, atol=1e-2, rtol=1e-2), print_if_assert_fails -@testing.parametrize("n", [200]) +@testing.parametrize("n", [50]) # usually: ~1.5 * num_significant_eigvals. # But logdet seems to converge sooo much faster. def test_logdet_spd_exact_for_full_order_lanczos(n): @@ -52,7 +52,7 @@ def test_logdet_spd_exact_for_full_order_lanczos(n): integrand = slq.integrand_logdet_spd(order, lambda v: A @ v) # Construct a vector without that does not have expected 2-norm equal to "dim" - x = prng.normal(prng.prng_key(seed=1), shape=(n,)) + 20 + x = prng.normal(prng.prng_key(seed=1), shape=(n,)) + 10 # Compute v^\top @ log(A) @ v via Lanczos received = integrand(x)