Skip to content

Commit

Permalink
Make integrand_funm_product_* expect algorithms as inputs (not algori…
Browse files Browse the repository at this point in the history
…thm parameters) (#212)

* Make bidiag-related matrix functions expect algorithm inputs instead of parameter inputs

* Make funm_bidiag expect a dense matrix function algorithm as an input

* Fix a tutorial
  • Loading branch information
pnkraemer authored Aug 29, 2024
1 parent 09a5f21 commit d7f1564
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 48 deletions.
26 changes: 16 additions & 10 deletions matfree/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def _extract_diag(x, offset=0):
return linalg.diagonal_matrix(diag, offset=offset)


def bidiag(depth: int, /, matrix_shape, materialize: bool = True):
def bidiag(depth: int, /, materialize: bool = True):
"""Construct an implementation of **bidiagonalisation**.
Uses pre-allocation and full reorthogonalisation.
Expand All @@ -553,17 +553,23 @@ def bidiag(depth: int, /, matrix_shape, materialize: bool = True):
consider using [tridiag_sym][matfree.decomp.tridiag_sym] for the time being.
"""
nrows, ncols = matrix_shape
max_depth = min(nrows, ncols) - 1
if depth > max_depth or depth < 0:
msg1 = f"Depth {depth} exceeds the matrix' dimensions. "
msg2 = f"Expected: 0 <= depth <= min(nrows, ncols) - 1 = {max_depth} "
msg3 = f"for a matrix with shape {matrix_shape}."
raise ValueError(msg1 + msg2 + msg3)

def estimate(Av: Callable, vA: Callable, v0, *parameters):
# Infer the size of A from v0
(ncols,) = np.shape(v0)
w0_like = func.eval_shape(Av, v0)
(nrows,) = np.shape(w0_like)

# Complain if the shapes don't match
max_depth = min(nrows, ncols) - 1
if depth > max_depth or depth < 0:
msg1 = f"Depth {depth} exceeds the matrix' dimensions. "
msg2 = f"Expected: 0 <= depth <= min(nrows, ncols) - 1 = {max_depth} "
msg3 = f"for a matrix with shape {(nrows, ncols)}."
raise ValueError(msg1 + msg2 + msg3)

v0_norm, length = _normalise(v0)
init_val = init(v0_norm)
init_val = init(v0_norm, nrows=nrows, ncols=ncols)

def body_fun(_, s):
return step(Av, vA, s, *parameters)
Expand All @@ -588,7 +594,7 @@ class State(containers.NamedTuple):
beta: Array
vk: Array

def init(init_vec: Array) -> State:
def init(init_vec: Array, *, nrows, ncols) -> State:
alphas = np.zeros((depth + 1,))
betas = np.zeros((depth + 1,))
Us = np.zeros((depth + 1, nrows))
Expand Down
8 changes: 3 additions & 5 deletions matfree/eig.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@

from matfree import decomp
from matfree.backend import linalg
from matfree.backend.typing import Array, Callable, Tuple
from matfree.backend.typing import Array, Callable


# todo: why does this function not return a callable?
def svd_partial(
v0: Array, depth: int, Av: Callable, vA: Callable, matrix_shape: Tuple[int, ...]
):
def svd_partial(v0: Array, depth: int, Av: Callable, vA: Callable):
"""Partial singular value decomposition.
Combines bidiagonalisation with full reorthogonalisation
Expand All @@ -30,7 +28,7 @@ def svd_partial(
Shape of the matrix involved in matrix-vector and vector-matrix products.
"""
# Factorise the matrix
algorithm = decomp.bidiag(depth, matrix_shape=matrix_shape, materialize=True)
algorithm = decomp.bidiag(depth, materialize=True)
(u, v), B, *_ = algorithm(Av, vA, v0)

# Compute SVD of factorisation
Expand Down
24 changes: 9 additions & 15 deletions matfree/funm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
"""

from matfree import decomp
from matfree.backend import containers, control_flow, func, linalg, np, tree_util
from matfree.backend.typing import Array, Callable

Expand Down Expand Up @@ -232,38 +231,34 @@ def matvec_flat(v_flat, *p):
return quadform


# todo: expect bidiag() to be passed here
def integrand_funm_product_logdet(depth, /):
def integrand_funm_product_logdet(bidiag: Callable, /):
r"""Construct the integrand for the log-determinant of a matrix-product.
Here, "product" refers to $X = A^\top A$.
"""
return integrand_funm_product(np.log, depth)
dense_funm = dense_funm_product_svd(np.log)
return integrand_funm_product(dense_funm, bidiag)


# todo: expect bidiag() to be passed here
def integrand_funm_product_schatten_norm(power, depth, /):
def integrand_funm_product_schatten_norm(power, bidiag: Callable, /):
r"""Construct the integrand for the $p$-th power of the Schatten-p norm."""

def matfun(x):
"""Matrix-function for Schatten-p norms."""
return x ** (power / 2)

return integrand_funm_product(matfun, depth)
dense_funm = dense_funm_product_svd(matfun)
return integrand_funm_product(dense_funm, bidiag)


# todo: expect bidiag() to be passed here
# todo: expect dense_funm_svd() to be passed here
def integrand_funm_product(matfun, depth, /):
def integrand_funm_product(dense_funm, algorithm, /):
r"""Construct the integrand for matrix-function-trace estimation.
Instead of the trace of a function of a matrix,
compute the trace of a function of the product of matrices.
Here, "product" refers to $X = A^\top A$.
"""

dense_funm = dense_funm_product_svd(matfun)

def quadform(matvecs, v0, *parameters):
matvec, vecmat = matvecs
v0_flat, v_unflatten = tree_util.ravel_pytree(v0)
Expand All @@ -277,8 +272,6 @@ def matvec_flat(v_flat, *p):
return flat, tree_util.partial_pytree(unflatten)

w0_flat, w_unflatten = func.eval_shape(matvec_flat, v0_flat)
matrix_shape = (*np.shape(w0_flat), *np.shape(v0_flat))
algorithm = decomp.bidiag(depth, matrix_shape=matrix_shape, materialize=True)

def vecmat_flat(w_flat):
w = w_unflatten(w_flat)
Expand All @@ -298,6 +291,8 @@ def vecmat_flat(w_flat):


def dense_funm_product_svd(matfun):
"""Implement dense matrix-functions of a product of matrices via SVDs."""

def dense_funm(matrix, /):
# Compute SVD of factorisation
_, S, Vt = linalg.svd(matrix, full_matrices=False)
Expand All @@ -311,7 +306,6 @@ def dense_funm(matrix, /):
return dense_funm


# todo: rename to *_eigh_sym
def dense_funm_sym_eigh(matfun):
"""Implement dense matrix-functions via symmetric eigendecompositions.
Expand Down
10 changes: 6 additions & 4 deletions tests/test_decomp/test_bidiag.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def Av(v):
def vA(v):
return v @ A

algorithm = decomp.bidiag(order, matrix_shape=np.shape(A), materialize=True)
algorithm = decomp.bidiag(order, materialize=True)
(U, V), B, res, ln = algorithm(Av, vA, v0)

test_util.assert_columns_orthonormal(U)
Expand All @@ -52,7 +52,8 @@ def test_error_too_high_depth(nrows, ncols, num_significant_singular_vals):
max_depth = min(nrows, ncols) - 1

with testing.raises(ValueError, match=""):
_ = decomp.bidiag(max_depth + 1, matrix_shape=np.shape(A), materialize=False)
alg = decomp.bidiag(max_depth + 1, materialize=False)
_ = alg(lambda v: A @ v, lambda v: A.T @ v, A[0])


@testing.parametrize("nrows", [5])
Expand All @@ -63,7 +64,8 @@ def test_error_too_low_depth(nrows, ncols, num_significant_singular_vals):
A = make_A(nrows, ncols, num_significant_singular_vals)
min_depth = 0
with testing.raises(ValueError, match=""):
_ = decomp.bidiag(min_depth - 1, matrix_shape=np.shape(A), materialize=False)
alg = decomp.bidiag(min_depth - 1, materialize=False)
_ = alg(lambda v: A @ v, lambda v: A.T @ v, A[0])


@testing.parametrize("nrows", [15])
Expand All @@ -81,7 +83,7 @@ def Av(v):
def vA(v):
return v @ A

algorithm = decomp.bidiag(0, matrix_shape=np.shape(A), materialize=False)
algorithm = decomp.bidiag(0, materialize=False)
(U, V), (d_m, e_m), res, ln = algorithm(Av, vA, v0)

assert np.shape(U) == (nrows, 1)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_eig/test_svd_partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def vA(v):

v0 = np.ones((ncols,))
v0 /= linalg.vector_norm(v0)
U, S, Vt = eig.svd_partial(v0, depth, Av, vA, matrix_shape=np.shape(A))
U, S, Vt = eig.svd_partial(v0, depth, Av, vA)
U_, S_, Vt_ = linalg.svd(A, full_matrices=False)

tols_decomp = {"atol": 1e-5, "rtol": 1e-5}
Expand Down
16 changes: 9 additions & 7 deletions tests/test_funm/test_integrand_funm_product_logdet.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""Test stochastic Lanczos quadrature for log-determinants of matrix-products."""

from matfree import funm, stochtrace, test_util
from matfree import decomp, funm, stochtrace, test_util
from matfree.backend import linalg, np, prng, testing


@testing.fixture()
def A(nrows, ncols, num_significant_singular_vals):
def make_A(nrows, ncols, num_significant_singular_vals):
"""Make a positive definite matrix with certain spectrum."""
# 'Invent' a spectrum. Use the number of pre-defined eigenvalues.
n = min(nrows, ncols)
Expand All @@ -18,9 +17,9 @@ def A(nrows, ncols, num_significant_singular_vals):
@testing.parametrize("ncols", [30])
@testing.parametrize("num_significant_singular_vals", [30])
@testing.parametrize("order", [20])
def test_logdet_product(A, order):
def test_logdet_product(nrows, ncols, num_significant_singular_vals, order):
"""Assert that logdet_product yields an accurate estimate."""
_, ncols = np.shape(A)
A = make_A(nrows, ncols, num_significant_singular_vals)
key = prng.prng_key(3)

def matvec(x):
Expand All @@ -31,7 +30,9 @@ def vecmat(x):

x_like = {"fx": np.ones((ncols,), dtype=float)}
fun = stochtrace.sampler_normal(x_like, num=400)
problem = funm.integrand_funm_product_logdet(order)

bidiag = decomp.bidiag(order)
problem = funm.integrand_funm_product_logdet(bidiag)
estimate = stochtrace.estimator(problem, fun)
received = estimate((matvec, vecmat), key)

Expand All @@ -53,7 +54,8 @@ def test_logdet_product_exact_for_full_order_lanczos(n):

# Set up max-order Lanczos approximation inside SLQ for the matrix-logarithm
order = n - 1
integrand = funm.integrand_funm_product_logdet(order)
bidiag = decomp.bidiag(order)
integrand = funm.integrand_funm_product_logdet(bidiag)

# Construct a vector without that does not have expected 2-norm equal to "dim"
x = prng.normal(prng.prng_key(seed=1), shape=(n,)) + 1
Expand Down
11 changes: 6 additions & 5 deletions tests/test_funm/test_integrand_funm_product_schatten_norm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
"""Test stochastic Lanczos quadrature for Schatten-p-norms."""

from matfree import funm, stochtrace, test_util
from matfree import decomp, funm, stochtrace, test_util
from matfree.backend import linalg, np, prng, testing


@testing.fixture()
def A(nrows, ncols, num_significant_singular_vals):
def make_A(nrows, ncols, num_significant_singular_vals):
"""Make a positive definite matrix with certain spectrum."""
# 'Invent' a spectrum. Use the number of pre-defined eigenvalues.
n = min(nrows, ncols)
Expand All @@ -19,15 +18,17 @@ def A(nrows, ncols, num_significant_singular_vals):
@testing.parametrize("num_significant_singular_vals", [30])
@testing.parametrize("order", [20])
@testing.parametrize("power", [1, 2, 5])
def test_schatten_norm(A, order, power):
def test_schatten_norm(nrows, ncols, num_significant_singular_vals, order, power):
"""Assert that the Schatten norm is accurate."""
A = make_A(nrows, ncols, num_significant_singular_vals)
_, s, _ = linalg.svd(A, full_matrices=False)
expected = np.sum(s**power)

_, ncols = np.shape(A)
args_like = np.ones((ncols,), dtype=float)
sampler = stochtrace.sampler_normal(args_like, num=500)
integrand = funm.integrand_funm_product_schatten_norm(power, order)
bidiag = decomp.bidiag(order)
integrand = funm.integrand_funm_product_schatten_norm(power, bidiag)
estimate = stochtrace.estimator(integrand, sampler)

key = prng.prng_key(1)
Expand Down
3 changes: 2 additions & 1 deletion tutorials/1_log_determinants.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ def vecmat_l(x):


order = 3
problem = funm.integrand_funm_product_logdet(order)
bidiag = decomp.bidiag(order)
problem = funm.integrand_funm_product_logdet(bidiag)
sampler = stochtrace.sampler_normal(x_like, num=1_000)
estimator = stochtrace.estimator(problem, sampler=sampler)
logdet = estimator((matvec_r, vecmat_l), jax.random.PRNGKey(1))
Expand Down

0 comments on commit d7f1564

Please sign in to comment.