Skip to content

Commit

Permalink
Make the partial SVD compatible with Pytrees (#236)
Browse files Browse the repository at this point in the history
* Make eig.svd_partial compatible with Pytrees

* Improve the documentation of the partial-SVD tests

* Delete a redundant test
  • Loading branch information
pnkraemer authored Jan 8, 2025
1 parent 524474c commit f64f901
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 10 deletions.
5 changes: 5 additions & 0 deletions matfree/backend/np.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"""

import jax.numpy as jnp
import jax.scipy.signal

# Creation functions:

Expand Down Expand Up @@ -189,6 +190,10 @@ def convolve(a, b, /, mode="full"):
return jnp.convolve(a, b, mode=mode)


def convolve2d(a, b, /):
return jax.scipy.signal.convolve2d(a, b)


def tril(x, /, shift=0):
return jnp.tril(x, shift)

Expand Down
8 changes: 6 additions & 2 deletions matfree/backend/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ def tree_map(func, pytree, *rest):


def ravel_pytree(pytree, /):
return jax.flatten_util.ravel_pytree(pytree)
ravelled, unravel = jax.flatten_util.ravel_pytree(pytree)

# Wrap through a partial_pytree() to make ravel_pytree
# compatible with jax.eval_shape
return ravelled, partial_pytree(unravel)


def tree_leaves(pytree, /):
Expand All @@ -25,4 +29,4 @@ def tree_structure(pytree, /):


def partial_pytree(func, /):
return jax.tree.Partial(func)
return jax.tree_util.Partial(func)
29 changes: 27 additions & 2 deletions matfree/eig.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Matrix-free eigenvalue and singular-value analysis."""

from matfree.backend import linalg
from matfree.backend import func, linalg, tree
from matfree.backend.typing import Array, Callable


Expand All @@ -21,8 +21,33 @@ def svd_partial(bidiag: Callable) -> Callable:
"""

def svd(Av: Callable, v0: Array, *parameters):
def Av_p(v):
return Av(v, *parameters)

# Evaluate the output shape of Av and flatten
u0 = func.eval_shape(Av_p, v0)

# Flatten in- and outputs
_, u_unravel = func.eval_shape(tree.ravel_pytree, u0)
v0_flat, v_unravel = tree.ravel_pytree(v0)

def Av_flat(v_flat):
"""Evaluate a flattened matvec."""
result = Av_p(v_unravel(v_flat))
result_flat, _ = tree.ravel_pytree(result)
return result_flat

# Call the flattened SVD
ut, s, vt = svd_flat(Av_flat, v0_flat)

# Select out_axes to ensure consistency with U @ diag(s) @ Vh
ut_tree = func.vmap(u_unravel)(ut)
vt_tree = func.vmap(v_unravel)(vt)
return ut_tree, s, vt_tree

def svd_flat(Av: Callable, v0: Array):
# Factorise the matrix
(u, v), B, *_ = bidiag(Av, v0, *parameters)
(u, v), B, *_ = bidiag(Av, v0)

# Compute SVD of factorisation
U, S, Vt = linalg.svd(B, full_matrices=False)
Expand Down
30 changes: 24 additions & 6 deletions tests/test_eig/test_svd_partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,7 @@ def test_equal_to_linalg_svd(nrows, ncols):
@testing.parametrize("nrows", [10])
@testing.parametrize("ncols", [3])
@testing.parametrize("num_matvecs", [0, 2, 3])
def test_shapes_as_expected(nrows, ncols, num_matvecs):
"""The output of full-depth SVD should be equal (*) to linalg.svd().
(*) Note: The singular values should be identical,
and the orthogonal matrices should be orthogonal. They are not unique.
"""
def test_shapes_as_expected_vectors(nrows, ncols, num_matvecs):
d = np.arange(1.0, 1.0 + min(nrows, ncols))
A = test_util.asymmetric_matrix_from_singular_values(d, nrows=nrows, ncols=ncols)
v0 = np.ones((ncols,))
Expand All @@ -49,3 +44,26 @@ def test_shapes_as_expected(nrows, ncols, num_matvecs):
assert Ut.shape == (num_matvecs, nrows)
assert S.shape == (num_matvecs,)
assert Vt.shape == (num_matvecs, ncols)


@testing.parametrize("nrows", [10])
@testing.parametrize("num_matvecs", [0, 2, 3])
def test_shapes_as_expected_lists_tuples(nrows, num_matvecs):
K = np.arange(1.0, 10.0).reshape((3, 3))
v0 = np.ones((nrows, nrows)) # tensor-valued input

# Map Pytrees to Pytrees
def Av(v: tuple, stencil) -> list:
(x,) = v
return [np.convolve2d(stencil, x)]

bidiag = decomp.bidiag(num_matvecs)
svd = eig.svd_partial(bidiag)
[Ut], S, (Vt,) = svd(Av, (v0,), K)

# Ut inherits pytree-shape from the outputs (list),
# and Vt inherits pytree-shape from the inputs (tuple)
[u0] = Av((v0,), K)
assert Ut.shape == (num_matvecs, *u0.shape)
assert S.shape == (num_matvecs,)
assert Vt.shape == (num_matvecs, *v0.shape)

0 comments on commit f64f901

Please sign in to comment.