diff --git a/matfree/decomp.py b/matfree/decomp.py index f7147f3..5b44e4f 100644 --- a/matfree/decomp.py +++ b/matfree/decomp.py @@ -535,7 +535,7 @@ def _extract_diag(x, offset=0): return linalg.diagonal_matrix(diag, offset=offset) -def bidiag(depth: int, /, matrix_shape, materialize: bool = True): +def bidiag(depth: int, /, materialize: bool = True): """Construct an implementation of **bidiagonalisation**. Uses pre-allocation and full reorthogonalisation. @@ -553,17 +553,23 @@ def bidiag(depth: int, /, matrix_shape, materialize: bool = True): consider using [tridiag_sym][matfree.decomp.tridiag_sym] for the time being. """ - nrows, ncols = matrix_shape - max_depth = min(nrows, ncols) - 1 - if depth > max_depth or depth < 0: - msg1 = f"Depth {depth} exceeds the matrix' dimensions. " - msg2 = f"Expected: 0 <= depth <= min(nrows, ncols) - 1 = {max_depth} " - msg3 = f"for a matrix with shape {matrix_shape}." - raise ValueError(msg1 + msg2 + msg3) def estimate(Av: Callable, vA: Callable, v0, *parameters): + # Infer the size of A from v0 + (ncols,) = np.shape(v0) + w0_like = func.eval_shape(Av, v0) + (nrows,) = np.shape(w0_like) + + # Complain if the shapes don't match + max_depth = min(nrows, ncols) - 1 + if depth > max_depth or depth < 0: + msg1 = f"Depth {depth} exceeds the matrix' dimensions. " + msg2 = f"Expected: 0 <= depth <= min(nrows, ncols) - 1 = {max_depth} " + msg3 = f"for a matrix with shape {(nrows, ncols)}." + raise ValueError(msg1 + msg2 + msg3) + v0_norm, length = _normalise(v0) - init_val = init(v0_norm) + init_val = init(v0_norm, nrows=nrows, ncols=ncols) def body_fun(_, s): return step(Av, vA, s, *parameters) @@ -588,7 +594,7 @@ class State(containers.NamedTuple): beta: Array vk: Array - def init(init_vec: Array) -> State: + def init(init_vec: Array, *, nrows, ncols) -> State: alphas = np.zeros((depth + 1,)) betas = np.zeros((depth + 1,)) Us = np.zeros((depth + 1, nrows)) diff --git a/matfree/eig.py b/matfree/eig.py index 6b1dc25..a56a9c9 100644 --- a/matfree/eig.py +++ b/matfree/eig.py @@ -2,13 +2,11 @@ from matfree import decomp from matfree.backend import linalg -from matfree.backend.typing import Array, Callable, Tuple +from matfree.backend.typing import Array, Callable # todo: why does this function not return a callable? -def svd_partial( - v0: Array, depth: int, Av: Callable, vA: Callable, matrix_shape: Tuple[int, ...] -): +def svd_partial(v0: Array, depth: int, Av: Callable, vA: Callable): """Partial singular value decomposition. Combines bidiagonalisation with full reorthogonalisation @@ -30,7 +28,7 @@ def svd_partial( Shape of the matrix involved in matrix-vector and vector-matrix products. """ # Factorise the matrix - algorithm = decomp.bidiag(depth, matrix_shape=matrix_shape, materialize=True) + algorithm = decomp.bidiag(depth, materialize=True) (u, v), B, *_ = algorithm(Av, vA, v0) # Compute SVD of factorisation diff --git a/matfree/funm.py b/matfree/funm.py index e0a1997..59d59c7 100644 --- a/matfree/funm.py +++ b/matfree/funm.py @@ -37,7 +37,6 @@ """ -from matfree import decomp from matfree.backend import containers, control_flow, func, linalg, np, tree_util from matfree.backend.typing import Array, Callable @@ -232,29 +231,27 @@ def matvec_flat(v_flat, *p): return quadform -# todo: expect bidiag() to be passed here -def integrand_funm_product_logdet(depth, /): +def integrand_funm_product_logdet(bidiag: Callable, /): r"""Construct the integrand for the log-determinant of a matrix-product. Here, "product" refers to $X = A^\top A$. """ - return integrand_funm_product(np.log, depth) + dense_funm = dense_funm_product_svd(np.log) + return integrand_funm_product(dense_funm, bidiag) -# todo: expect bidiag() to be passed here -def integrand_funm_product_schatten_norm(power, depth, /): +def integrand_funm_product_schatten_norm(power, bidiag: Callable, /): r"""Construct the integrand for the $p$-th power of the Schatten-p norm.""" def matfun(x): """Matrix-function for Schatten-p norms.""" return x ** (power / 2) - return integrand_funm_product(matfun, depth) + dense_funm = dense_funm_product_svd(matfun) + return integrand_funm_product(dense_funm, bidiag) -# todo: expect bidiag() to be passed here -# todo: expect dense_funm_svd() to be passed here -def integrand_funm_product(matfun, depth, /): +def integrand_funm_product(dense_funm, algorithm, /): r"""Construct the integrand for matrix-function-trace estimation. Instead of the trace of a function of a matrix, @@ -262,8 +259,6 @@ def integrand_funm_product(matfun, depth, /): Here, "product" refers to $X = A^\top A$. """ - dense_funm = dense_funm_product_svd(matfun) - def quadform(matvecs, v0, *parameters): matvec, vecmat = matvecs v0_flat, v_unflatten = tree_util.ravel_pytree(v0) @@ -277,8 +272,6 @@ def matvec_flat(v_flat, *p): return flat, tree_util.partial_pytree(unflatten) w0_flat, w_unflatten = func.eval_shape(matvec_flat, v0_flat) - matrix_shape = (*np.shape(w0_flat), *np.shape(v0_flat)) - algorithm = decomp.bidiag(depth, matrix_shape=matrix_shape, materialize=True) def vecmat_flat(w_flat): w = w_unflatten(w_flat) @@ -298,6 +291,8 @@ def vecmat_flat(w_flat): def dense_funm_product_svd(matfun): + """Implement dense matrix-functions of a product of matrices via SVDs.""" + def dense_funm(matrix, /): # Compute SVD of factorisation _, S, Vt = linalg.svd(matrix, full_matrices=False) @@ -311,7 +306,6 @@ def dense_funm(matrix, /): return dense_funm -# todo: rename to *_eigh_sym def dense_funm_sym_eigh(matfun): """Implement dense matrix-functions via symmetric eigendecompositions. diff --git a/tests/test_decomp/test_bidiag.py b/tests/test_decomp/test_bidiag.py index 7305e08..f01ae5d 100644 --- a/tests/test_decomp/test_bidiag.py +++ b/tests/test_decomp/test_bidiag.py @@ -31,7 +31,7 @@ def Av(v): def vA(v): return v @ A - algorithm = decomp.bidiag(order, matrix_shape=np.shape(A), materialize=True) + algorithm = decomp.bidiag(order, materialize=True) (U, V), B, res, ln = algorithm(Av, vA, v0) test_util.assert_columns_orthonormal(U) @@ -52,7 +52,8 @@ def test_error_too_high_depth(nrows, ncols, num_significant_singular_vals): max_depth = min(nrows, ncols) - 1 with testing.raises(ValueError, match=""): - _ = decomp.bidiag(max_depth + 1, matrix_shape=np.shape(A), materialize=False) + alg = decomp.bidiag(max_depth + 1, materialize=False) + _ = alg(lambda v: A @ v, lambda v: A.T @ v, A[0]) @testing.parametrize("nrows", [5]) @@ -63,7 +64,8 @@ def test_error_too_low_depth(nrows, ncols, num_significant_singular_vals): A = make_A(nrows, ncols, num_significant_singular_vals) min_depth = 0 with testing.raises(ValueError, match=""): - _ = decomp.bidiag(min_depth - 1, matrix_shape=np.shape(A), materialize=False) + alg = decomp.bidiag(min_depth - 1, materialize=False) + _ = alg(lambda v: A @ v, lambda v: A.T @ v, A[0]) @testing.parametrize("nrows", [15]) @@ -81,7 +83,7 @@ def Av(v): def vA(v): return v @ A - algorithm = decomp.bidiag(0, matrix_shape=np.shape(A), materialize=False) + algorithm = decomp.bidiag(0, materialize=False) (U, V), (d_m, e_m), res, ln = algorithm(Av, vA, v0) assert np.shape(U) == (nrows, 1) diff --git a/tests/test_eig/test_svd_partial.py b/tests/test_eig/test_svd_partial.py index 655ffce..c852e02 100644 --- a/tests/test_eig/test_svd_partial.py +++ b/tests/test_eig/test_svd_partial.py @@ -34,7 +34,7 @@ def vA(v): v0 = np.ones((ncols,)) v0 /= linalg.vector_norm(v0) - U, S, Vt = eig.svd_partial(v0, depth, Av, vA, matrix_shape=np.shape(A)) + U, S, Vt = eig.svd_partial(v0, depth, Av, vA) U_, S_, Vt_ = linalg.svd(A, full_matrices=False) tols_decomp = {"atol": 1e-5, "rtol": 1e-5} diff --git a/tests/test_funm/test_integrand_funm_product_logdet.py b/tests/test_funm/test_integrand_funm_product_logdet.py index bf1fcbf..c3ceda0 100644 --- a/tests/test_funm/test_integrand_funm_product_logdet.py +++ b/tests/test_funm/test_integrand_funm_product_logdet.py @@ -1,11 +1,10 @@ """Test stochastic Lanczos quadrature for log-determinants of matrix-products.""" -from matfree import funm, stochtrace, test_util +from matfree import decomp, funm, stochtrace, test_util from matfree.backend import linalg, np, prng, testing -@testing.fixture() -def A(nrows, ncols, num_significant_singular_vals): +def make_A(nrows, ncols, num_significant_singular_vals): """Make a positive definite matrix with certain spectrum.""" # 'Invent' a spectrum. Use the number of pre-defined eigenvalues. n = min(nrows, ncols) @@ -18,9 +17,9 @@ def A(nrows, ncols, num_significant_singular_vals): @testing.parametrize("ncols", [30]) @testing.parametrize("num_significant_singular_vals", [30]) @testing.parametrize("order", [20]) -def test_logdet_product(A, order): +def test_logdet_product(nrows, ncols, num_significant_singular_vals, order): """Assert that logdet_product yields an accurate estimate.""" - _, ncols = np.shape(A) + A = make_A(nrows, ncols, num_significant_singular_vals) key = prng.prng_key(3) def matvec(x): @@ -31,7 +30,9 @@ def vecmat(x): x_like = {"fx": np.ones((ncols,), dtype=float)} fun = stochtrace.sampler_normal(x_like, num=400) - problem = funm.integrand_funm_product_logdet(order) + + bidiag = decomp.bidiag(order) + problem = funm.integrand_funm_product_logdet(bidiag) estimate = stochtrace.estimator(problem, fun) received = estimate((matvec, vecmat), key) @@ -53,7 +54,8 @@ def test_logdet_product_exact_for_full_order_lanczos(n): # Set up max-order Lanczos approximation inside SLQ for the matrix-logarithm order = n - 1 - integrand = funm.integrand_funm_product_logdet(order) + bidiag = decomp.bidiag(order) + integrand = funm.integrand_funm_product_logdet(bidiag) # Construct a vector without that does not have expected 2-norm equal to "dim" x = prng.normal(prng.prng_key(seed=1), shape=(n,)) + 1 diff --git a/tests/test_funm/test_integrand_funm_product_schatten_norm.py b/tests/test_funm/test_integrand_funm_product_schatten_norm.py index 4ce687a..34a9579 100644 --- a/tests/test_funm/test_integrand_funm_product_schatten_norm.py +++ b/tests/test_funm/test_integrand_funm_product_schatten_norm.py @@ -1,11 +1,10 @@ """Test stochastic Lanczos quadrature for Schatten-p-norms.""" -from matfree import funm, stochtrace, test_util +from matfree import decomp, funm, stochtrace, test_util from matfree.backend import linalg, np, prng, testing -@testing.fixture() -def A(nrows, ncols, num_significant_singular_vals): +def make_A(nrows, ncols, num_significant_singular_vals): """Make a positive definite matrix with certain spectrum.""" # 'Invent' a spectrum. Use the number of pre-defined eigenvalues. n = min(nrows, ncols) @@ -19,15 +18,17 @@ def A(nrows, ncols, num_significant_singular_vals): @testing.parametrize("num_significant_singular_vals", [30]) @testing.parametrize("order", [20]) @testing.parametrize("power", [1, 2, 5]) -def test_schatten_norm(A, order, power): +def test_schatten_norm(nrows, ncols, num_significant_singular_vals, order, power): """Assert that the Schatten norm is accurate.""" + A = make_A(nrows, ncols, num_significant_singular_vals) _, s, _ = linalg.svd(A, full_matrices=False) expected = np.sum(s**power) _, ncols = np.shape(A) args_like = np.ones((ncols,), dtype=float) sampler = stochtrace.sampler_normal(args_like, num=500) - integrand = funm.integrand_funm_product_schatten_norm(power, order) + bidiag = decomp.bidiag(order) + integrand = funm.integrand_funm_product_schatten_norm(power, bidiag) estimate = stochtrace.estimator(integrand, sampler) key = prng.prng_key(1) diff --git a/tutorials/1_log_determinants.py b/tutorials/1_log_determinants.py index 3b9454e..92c52cd 100644 --- a/tutorials/1_log_determinants.py +++ b/tutorials/1_log_determinants.py @@ -59,7 +59,8 @@ def vecmat_l(x): order = 3 -problem = funm.integrand_funm_product_logdet(order) +bidiag = decomp.bidiag(order) +problem = funm.integrand_funm_product_logdet(bidiag) sampler = stochtrace.sampler_normal(x_like, num=1_000) estimator = stochtrace.estimator(problem, sampler=sampler) logdet = estimator((matvec_r, vecmat_l), jax.random.PRNGKey(1))