Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make SLQ integrands behave correctly for arbitrary input vectors #174

Merged
merged 2 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions matfree/slq.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ def integrand_slq_spd(matfun, order, matvec, /):

def quadform(v0, *parameters):
v0_flat, v_unflatten = tree_util.ravel_pytree(v0)
v0_flat /= linalg.vector_norm(v0_flat)
length = linalg.vector_norm(v0_flat)
v0_flat /= length

def matvec_flat(v_flat):
v = v_unflatten(v_flat)
Expand All @@ -50,10 +51,8 @@ def matvec_flat(v_flat):

# Since Q orthogonal (orthonormal) to v0, Q v = Q[0],
# and therefore (Q v)^T f(D) (Qv) = Q[0] * f(diag) * Q[0]
(dim,) = v0_flat.shape

fx_eigvals = func.vmap(matfun)(eigvals)
return dim * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :])
return length**2 * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :])

return quadform

Expand Down Expand Up @@ -86,7 +85,8 @@ def integrand_slq_product(matfun, depth, matvec, vecmat, /):

def quadform(v0, *parameters):
v0_flat, v_unflatten = tree_util.ravel_pytree(v0)
v0_flat /= linalg.vector_norm(v0_flat)
length = linalg.vector_norm(v0_flat)
v0_flat /= length

def matvec_flat(v_flat):
v = v_unflatten(v_flat)
Expand Down Expand Up @@ -115,10 +115,9 @@ def vecmat_flat(w_flat):

# Since Q orthogonal (orthonormal) to v0, Q v = Q[0],
# and therefore (Q v)^T f(D) (Qv) = Q[0] * f(diag) * Q[0]
_, ncols = matrix_shape
eigvals, eigvecs = S**2, Vt.T
fx_eigvals = func.vmap(matfun)(eigvals)
return ncols * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :])
return length**2 * linalg.vecdot(eigvecs[0, :], fx_eigvals * eigvecs[0, :])

return quadform

Expand Down
30 changes: 30 additions & 0 deletions tests/test_slq/test_logdet_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,33 @@ def vecmat(x):
expected = linalg.slogdet(A.T @ A)[1]
print_if_assert_fails = ("error", np.abs(received - expected), "target:", expected)
assert np.allclose(received, expected, atol=1e-2, rtol=1e-2), print_if_assert_fails


@testing.parametrize("n", [50])
# usually: ~1.5 * num_significant_eigvals.
# But logdet seems to converge sooo much faster.
def test_logdet_product_exact_for_full_order_lanczos(n):
    r"""Computing v^\top f(A^\top @ A) v with max-order Lanczos is exact for _any_ v."""
    # A numerically well-behaved test matrix with a prescribed singular spectrum.
    singular_values = np.sqrt(np.arange(1.0, 1.0 + n, step=1.0))
    A = test_util.asymmetric_matrix_from_singular_values(
        singular_values, nrows=n, ncols=n
    )

    # Full-order Lanczos inside SLQ for the log-determinant of A^\top A.
    order = n - 1
    integrand = slq.integrand_logdet_product(order, lambda v: A @ v, lambda v: v @ A)

    # A probe vector whose squared 2-norm is deliberately NOT the dimension,
    # so the test catches integrands that assume a normalized/scaled input.
    x = prng.normal(prng.prng_key(seed=1), shape=(n,)) + 1

    # Quadratic form v^\top @ log(A^\top A) @ v via Lanczos.
    received = integrand(x)

    # Reference value via a dense eigendecomposition of A^\top A.
    vals, vecs = linalg.eigh(A.T @ A)
    log_ata = vecs @ linalg.diagonal_matrix(np.log(vals)) @ vecs.T
    expected = x.T @ log_ata @ x

    # Max-order Lanczos is exact, so the two must match.
    assert np.allclose(received, expected)
28 changes: 28 additions & 0 deletions tests/test_slq/test_logdet_spd.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,31 @@ def matvec(x):
expected = linalg.slogdet(A)[1]
print_if_assert_fails = ("error", np.abs(received - expected), "target:", expected)
assert np.allclose(received, expected, atol=1e-2, rtol=1e-2), print_if_assert_fails


@testing.parametrize("n", [50])
# usually: ~1.5 * num_significant_eigvals.
# But logdet seems to converge sooo much faster.
def test_logdet_spd_exact_for_full_order_lanczos(n):
    r"""Computing v^\top f(A) v with max-order Lanczos should be exact for _any_ v."""
    # A symmetric matrix with a prescribed, numerically benign spectrum.
    spectrum = np.arange(1.0, 1.0 + n, step=1.0)
    A = test_util.symmetric_matrix_from_eigenvalues(spectrum)

    # Full-order Lanczos inside SLQ for the matrix logarithm.
    order = n - 1
    integrand = slq.integrand_logdet_spd(order, lambda v: A @ v)

    # A probe vector whose squared 2-norm is deliberately NOT the dimension,
    # so the test catches integrands that assume a normalized/scaled input.
    x = prng.normal(prng.prng_key(seed=1), shape=(n,)) + 10

    # Quadratic form v^\top @ log(A) @ v via Lanczos.
    received = integrand(x)

    # Reference value via a dense eigendecomposition; avoid reusing the
    # name of the constructed spectrum above.
    vals, vecs = linalg.eigh(A)
    log_a = vecs @ linalg.diagonal_matrix(np.log(vals)) @ vecs.T
    expected = x.T @ log_a @ x

    # Max-order Lanczos is exact, so the two must match.
    assert np.allclose(received, expected)
Loading