Skip to content

Commit

Permalink
Automatically implement vector matrix products with jax.vjp (#229)
Browse files Browse the repository at this point in the history
* Remove 'vA' from the API and use VJPs instead

* Autoupdate the pre-commit hook
  • Loading branch information
pnkraemer authored Jan 7, 2025
1 parent daf9ef1 commit 2bcba5f
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 47 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,23 @@ default_language_version:
python: python3
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
rev: v5.0.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- id: check-merge-conflict
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.6.4
rev: v0.8.6
hooks:
# Run the linter.
- id: ruff
args: [--fix]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.11.2
rev: v1.14.1
hooks:
- id: mypy
args:
Expand Down
15 changes: 10 additions & 5 deletions matfree/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,8 @@ def bidiag(depth: int, /, materialize: bool = True):
Decompose a matrix into a product of orthogonal-**bidiagonal**-orthogonal matrices.
Use this algorithm for approximate **singular value** decompositions.
Internally, Matfree uses JAX to turn matrix-vector- into vector-matrix-products.
??? note "A note about differentiability"
Unlike [tridiag_sym][matfree.decomp.tridiag_sym] or
[hessenberg][matfree.decomp.hessenberg], this function's reverse-mode
Expand All @@ -592,7 +594,7 @@ def bidiag(depth: int, /, materialize: bool = True):
"""

def estimate(Av: Callable, vA: Callable, v0, *parameters):
def estimate(Av: Callable, v0, *parameters):
# Infer the size of A from v0
(ncols,) = np.shape(v0)
w0_like = func.eval_shape(Av, v0)
Expand All @@ -610,7 +612,7 @@ def estimate(Av: Callable, vA: Callable, v0, *parameters):
init_val = init(v0_norm, nrows=nrows, ncols=ncols)

def body_fun(_, s):
return step(Av, vA, s, *parameters)
return step(Av, s, *parameters)

result = control_flow.fori_loop(
0, depth + 1, body_fun=body_fun, init_val=init_val
Expand Down Expand Up @@ -640,18 +642,21 @@ def init(init_vec: Array, *, nrows, ncols) -> State:
v0, _ = _normalise(init_vec)
return State(0, Us, Vs, alphas, betas, np.zeros(()), v0)

def step(Av, vA, state: State, *parameters) -> State:
def step(Av, state: State, *parameters) -> State:
i, Us, Vs, alphas, betas, beta, vk = state
Vs = Vs.at[i].set(vk)
betas = betas.at[i].set(beta)

uk = Av(vk, *parameters) - beta * Us[i - 1]
# Use jax.vjp to evaluate the vector-matrix product
Av_eval, vA = func.vjp(lambda v: Av(v, *parameters), vk)
uk = Av_eval - beta * Us[i - 1]
uk, alpha = _normalise(uk)
uk, *_ = _gram_schmidt_classical(uk, Us) # full reorthogonalisation
Us = Us.at[i].set(uk)
alphas = alphas.at[i].set(alpha)

vk = vA(uk, *parameters) - alpha * vk
(vA_eval,) = vA(uk)
vk = vA_eval - alpha * vk
vk, beta = _normalise(vk)
vk, *_ = _gram_schmidt_classical(vk, Vs) # full reorthogonalisation

Expand Down
4 changes: 2 additions & 2 deletions matfree/eig.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


# todo: why does this function not return a callable?
def svd_partial(v0: Array, depth: int, Av: Callable, vA: Callable):
def svd_partial(v0: Array, depth: int, Av: Callable):
"""Partial singular value decomposition.
Combines bidiagonalisation with full reorthogonalisation
Expand All @@ -27,7 +27,7 @@ def svd_partial(v0: Array, depth: int, Av: Callable, vA: Callable):
"""
# Factorise the matrix
algorithm = decomp.bidiag(depth, materialize=True)
(u, v), B, *_ = algorithm(Av, vA, v0)
(u, v), B, *_ = algorithm(Av, v0)

# Compute SVD of factorisation
U, S, Vt = linalg.svd(B, full_matrices=False)
Expand Down
19 changes: 5 additions & 14 deletions matfree/funm.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,30 +259,21 @@ def integrand_funm_product(dense_funm, algorithm, /):
Here, "product" refers to $X = A^\top A$.
"""

def quadform(matvecs, v0, *parameters):
matvec, vecmat = matvecs
def quadform(matvec, v0, *parameters):
v0_flat, v_unflatten = tree_util.ravel_pytree(v0)
length = linalg.vector_norm(v0_flat)
v0_flat /= length

def matvec_flat(v_flat, *p):
v = v_unflatten(v_flat)
Av = matvec(v, *p)
flat, unflatten = tree_util.ravel_pytree(Av)
return flat, tree_util.partial_pytree(unflatten)

w0_flat, w_unflatten = func.eval_shape(matvec_flat, v0_flat)

def vecmat_flat(w_flat):
w = w_unflatten(w_flat)
wA = vecmat(w, *parameters)
return tree_util.ravel_pytree(wA)[0]
flat, _unflatten = tree_util.ravel_pytree(Av)
return flat

# Decompose into orthogonal-bidiag-orthogonal
matvec_flat_p = lambda v: matvec_flat(v)[0] # noqa: E731
output = algorithm(matvec_flat_p, vecmat_flat, v0_flat, *parameters)
_, B, *_ = output
_, B, *_ = algorithm(matvec_flat, v0_flat, *parameters)

# Evaluate matfun
fA = dense_funm(B)
e1 = np.eye(len(fA))[0, :]
return length**2 * linalg.inner(e1, fA @ e1)
Expand Down
12 changes: 6 additions & 6 deletions tests/test_decomp/test_bidiag.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def vA(v):
return v @ A

algorithm = decomp.bidiag(order, materialize=True)
(U, V), B, res, ln = algorithm(Av, vA, v0)
(U, V), B, res, ln = algorithm(Av, v0)

test_util.assert_columns_orthonormal(U)
test_util.assert_columns_orthonormal(V)
Expand All @@ -51,9 +51,9 @@ def test_error_too_high_depth(nrows, ncols, num_significant_singular_vals):
A = make_A(nrows, ncols, num_significant_singular_vals)
max_depth = min(nrows, ncols) - 1

with testing.raises(ValueError, match=""):
with testing.raises(ValueError, match="exceeds"):
alg = decomp.bidiag(max_depth + 1, materialize=False)
_ = alg(lambda v: A @ v, lambda v: A.T @ v, A[0])
_ = alg(lambda v: A @ v, A[0])


@testing.parametrize("nrows", [5])
Expand All @@ -63,9 +63,9 @@ def test_error_too_low_depth(nrows, ncols, num_significant_singular_vals):
"""Assert that a ValueError is raised when the depth is negative."""
A = make_A(nrows, ncols, num_significant_singular_vals)
min_depth = 0
with testing.raises(ValueError, match=""):
with testing.raises(ValueError, match="exceeds"):
alg = decomp.bidiag(min_depth - 1, materialize=False)
_ = alg(lambda v: A @ v, lambda v: A.T @ v, A[0])
_ = alg(lambda v: A @ v, A[0])


@testing.parametrize("nrows", [15])
Expand All @@ -84,7 +84,7 @@ def vA(v):
return v @ A

algorithm = decomp.bidiag(0, materialize=False)
(U, V), (d_m, e_m), res, ln = algorithm(Av, vA, v0)
(U, V), (d_m, e_m), res, ln = algorithm(Av, v0)

assert np.shape(U) == (nrows, 1)
assert np.shape(V) == (ncols, 1)
Expand Down
5 changes: 1 addition & 4 deletions tests/test_eig/test_svd_partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,9 @@ def test_equal_to_linalg_svd(A):
def Av(v):
return A @ v

def vA(v):
return v @ A

v0 = np.ones((ncols,))
v0 /= linalg.vector_norm(v0)
U, S, Vt = eig.svd_partial(v0, depth, Av, vA)
U, S, Vt = eig.svd_partial(v0, depth, Av)
U_, S_, Vt_ = linalg.svd(A, full_matrices=False)

tols_decomp = {"atol": 1e-5, "rtol": 1e-5}
Expand Down
7 changes: 2 additions & 5 deletions tests/test_funm/test_integrand_funm_product_logdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,13 @@ def test_logdet_product(nrows, ncols, num_significant_singular_vals, order):
def matvec(x):
return {"fx": A @ x["fx"]}

def vecmat(x):
return {"fx": x["fx"] @ A}

x_like = {"fx": np.ones((ncols,), dtype=float)}
fun = stochtrace.sampler_normal(x_like, num=400)

bidiag = decomp.bidiag(order)
problem = funm.integrand_funm_product_logdet(bidiag)
estimate = stochtrace.estimator(problem, fun)
received = estimate((matvec, vecmat), key)
received = estimate(matvec, key)

expected = linalg.slogdet(A.T @ A)[1]
print_if_assert_fails = ("error", np.abs(received - expected), "target:", expected)
Expand All @@ -61,7 +58,7 @@ def test_logdet_product_exact_for_full_order_lanczos(n):
x = prng.normal(prng.prng_key(seed=1), shape=(n,)) + 1

# Compute v^\top @ log(A) @ v via Lanczos
received = integrand((lambda v: A @ v, lambda v: v @ A), x)
received = integrand((lambda v: A @ v), x)

# Compute the "true" value of v^\top @ log(A) @ v via eigenvalues
eigvals, eigvecs = linalg.eigh(A.T @ A)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_schatten_norm(nrows, ncols, num_significant_singular_vals, order, power
estimate = stochtrace.estimator(integrand, sampler)

key = prng.prng_key(1)
received = estimate((lambda v: A @ v, lambda v: A.T @ v), key)
received = estimate(lambda v: A @ v, key)

print_if_assert_fails = ("error", np.abs(received - expected), "target:", expected)
assert np.allclose(received, expected, atol=1e-2, rtol=1e-2), print_if_assert_fails
12 changes: 5 additions & 7 deletions tutorials/1_log_determinants.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,24 +48,22 @@ def matvec(x):
A /= nrows**2


def matvec_r(x):
def matvec_half(x):
"""Compute a matrix-vector product."""
return A @ x


def vecmat_l(x):
"""Compute a vector-matrix product."""
return x @ A


order = 3
bidiag = decomp.bidiag(order)
problem = funm.integrand_funm_product_logdet(bidiag)
sampler = stochtrace.sampler_normal(x_like, num=1_000)
estimator = stochtrace.estimator(problem, sampler=sampler)
logdet = estimator((matvec_r, vecmat_l), jax.random.PRNGKey(1))
logdet = estimator(matvec_half, jax.random.PRNGKey(1))
print(logdet)

# Internally, Matfree uses JAX's vector-Jacobian products to
# turn the matrix-vector product into a vector-matrix product.
#
# For comparison:

print(jnp.linalg.slogdet(A.T @ A))

0 comments on commit 2bcba5f

Please sign in to comment.