From f64f901ffd06acb8aceb28d879a6d775debff21b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Wed, 8 Jan 2025 12:03:56 +0100 Subject: [PATCH] Make the partial SVD compatible with Pytrees (#236) * Make eig.svd_partial compatible with Pytrees * Improve the documentation of the partial-SVD tests * Delete a redundant test --- matfree/backend/np.py | 5 +++++ matfree/backend/tree.py | 8 ++++++-- matfree/eig.py | 29 +++++++++++++++++++++++++++-- tests/test_eig/test_svd_partial.py | 30 ++++++++++++++++++++++++------ 4 files changed, 62 insertions(+), 10 deletions(-) diff --git a/matfree/backend/np.py b/matfree/backend/np.py index f8e1379..9cc776f 100644 --- a/matfree/backend/np.py +++ b/matfree/backend/np.py @@ -17,6 +17,7 @@ """ import jax.numpy as jnp +import jax.scipy.signal # Creation functions: @@ -189,6 +190,10 @@ def convolve(a, b, /, mode="full"): return jnp.convolve(a, b, mode=mode) +def convolve2d(a, b, /): + return jax.scipy.signal.convolve2d(a, b) + + def tril(x, /, shift=0): return jnp.tril(x, shift) diff --git a/matfree/backend/tree.py b/matfree/backend/tree.py index ddc786b..1a4e211 100644 --- a/matfree/backend/tree.py +++ b/matfree/backend/tree.py @@ -13,7 +13,11 @@ def tree_map(func, pytree, *rest): def ravel_pytree(pytree, /): - return jax.flatten_util.ravel_pytree(pytree) + ravelled, unravel = jax.flatten_util.ravel_pytree(pytree) + + # Wrap through a partial_pytree() to make ravel_pytree + # compatible with jax.eval_shape + return ravelled, partial_pytree(unravel) def tree_leaves(pytree, /): @@ -25,4 +29,4 @@ def tree_structure(pytree, /): def partial_pytree(func, /): - return jax.tree.Partial(func) + return jax.tree_util.Partial(func) diff --git a/matfree/eig.py b/matfree/eig.py index 0c40b37..b32aa37 100644 --- a/matfree/eig.py +++ b/matfree/eig.py @@ -1,6 +1,6 @@ """Matrix-free eigenvalue and singular-value analysis.""" -from matfree.backend import linalg +from matfree.backend import func, linalg, tree from matfree.backend.typing import Array, Callable @@ -21,8 +21,33 @@ def svd_partial(bidiag: Callable) -> Callable: """ def svd(Av: Callable, v0: Array, *parameters): + def Av_p(v): + return Av(v, *parameters) + + # Evaluate the output shape of Av and flatten + u0 = func.eval_shape(Av_p, v0) + + # Flatten in- and outputs + _, u_unravel = func.eval_shape(tree.ravel_pytree, u0) + v0_flat, v_unravel = tree.ravel_pytree(v0) + + def Av_flat(v_flat): + """Evaluate a flattened matvec.""" + result = Av_p(v_unravel(v_flat)) + result_flat, _ = tree.ravel_pytree(result) + return result_flat + + # Call the flattened SVD + ut, s, vt = svd_flat(Av_flat, v0_flat) + + # Select out_axes to ensure consistency with U @ diag(s) @ Vh + ut_tree = func.vmap(u_unravel)(ut) + vt_tree = func.vmap(v_unravel)(vt) + return ut_tree, s, vt_tree + + def svd_flat(Av: Callable, v0: Array): # Factorise the matrix - (u, v), B, *_ = bidiag(Av, v0, *parameters) + (u, v), B, *_ = bidiag(Av, v0) # Compute SVD of factorisation U, S, Vt = linalg.svd(B, full_matrices=False) diff --git a/tests/test_eig/test_svd_partial.py b/tests/test_eig/test_svd_partial.py index ac9deec..59bc9c5 100644 --- a/tests/test_eig/test_svd_partial.py +++ b/tests/test_eig/test_svd_partial.py @@ -32,12 +32,7 @@ def test_equal_to_linalg_svd(nrows, ncols): @testing.parametrize("nrows", [10]) @testing.parametrize("ncols", [3]) @testing.parametrize("num_matvecs", [0, 2, 3]) -def test_shapes_as_expected(nrows, ncols, num_matvecs): - """The output of full-depth SVD should be equal (*) to linalg.svd(). - - (*) Note: The singular values should be identical, - and the orthogonal matrices should be orthogonal. They are not unique. - """ +def test_shapes_as_expected_vectors(nrows, ncols, num_matvecs): d = np.arange(1.0, 1.0 + min(nrows, ncols)) A = test_util.asymmetric_matrix_from_singular_values(d, nrows=nrows, ncols=ncols) v0 = np.ones((ncols,)) @@ -49,3 +44,26 @@ def test_shapes_as_expected(nrows, ncols, num_matvecs): assert Ut.shape == (num_matvecs, nrows) assert S.shape == (num_matvecs,) assert Vt.shape == (num_matvecs, ncols) + + +@testing.parametrize("nrows", [10]) +@testing.parametrize("num_matvecs", [0, 2, 3]) +def test_shapes_as_expected_lists_tuples(nrows, num_matvecs): + K = np.arange(1.0, 10.0).reshape((3, 3)) + v0 = np.ones((nrows, nrows)) # tensor-valued input + + # Map Pytrees to Pytrees + def Av(v: tuple, stencil) -> list: + (x,) = v + return [np.convolve2d(stencil, x)] + + bidiag = decomp.bidiag(num_matvecs) + svd = eig.svd_partial(bidiag) + [Ut], S, (Vt,) = svd(Av, (v0,), K) + + # Ut inherits pytree-shape from the outputs (list), + # and Vt inherits pytree-shape from the inputs (tuple) + [u0] = Av((v0,), K) + assert Ut.shape == (num_matvecs, *u0.shape) + assert S.shape == (num_matvecs,) + assert Vt.shape == (num_matvecs, *v0.shape)