Merge montecarlo/hutchinson and decomp/lanczos modules #161

Merged: 4 commits, Nov 21, 2023
1 change: 1 addition & 0 deletions Makefile
@@ -20,6 +20,7 @@ clean:
rm -rf *.egg-info
rm -rf dist site build htmlcov
rm -rf *.ipynb_checkpoints
rm matfree/_version.py

doc:
mkdocs build
4 changes: 2 additions & 2 deletions README.md
@@ -51,7 +51,7 @@ Import matfree and JAX, and set up a test problem.
```python
>>> import jax
>>> import jax.numpy as jnp
>>> from matfree import hutchinson, montecarlo, slq
>>> from matfree import hutchinson, slq

>>> A = jnp.reshape(jnp.arange(12.0), (6, 2))
>>>
@@ -65,7 +65,7 @@ Estimate the trace of the matrix:

```python
>>> key = jax.random.PRNGKey(1)
>>> normal = montecarlo.normal(shape=(2,))
>>> normal = hutchinson.normal(shape=(2,))
>>> trace = hutchinson.trace(matvec, key=key, sample_fun=normal)
>>>
>>> print(jnp.round(trace))
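To see how the pieces above fit together after the merge, here is a minimal, self-contained sketch of the updated README example. The `matvec` definition is not visible in this hunk and is reconstructed from the analogous docs snippets in this PR; the expected output shown in the README is not reproduced here.

```python
import jax
import jax.numpy as jnp
from matfree import hutchinson  # the former montecarlo samplers now live here

A = jnp.reshape(jnp.arange(12.0), (6, 2))

def matvec(x):
    # Matrix-vector product with A^T A, as in the docs examples in this PR.
    return A.T @ (A @ x)

key = jax.random.PRNGKey(1)
normal = hutchinson.normal(shape=(2,))  # previously montecarlo.normal
trace = hutchinson.trace(matvec, key=key, sample_fun=normal)
print(jnp.round(trace))
```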
3 changes: 0 additions & 3 deletions docs/api/lanczos.md

This file was deleted.

2 changes: 0 additions & 2 deletions docs/api/montecarlo.md

This file was deleted.

4 changes: 2 additions & 2 deletions docs/benchmarks/control_variates.py
@@ -3,7 +3,7 @@
Runtime: ~10 seconds.
"""

from matfree import benchmark_util, hutchinson, montecarlo
from matfree import benchmark_util, hutchinson
from matfree.backend import func, linalg, np, plt, prng, progressbar


@@ -22,7 +22,7 @@ def f(x):
_, jvp = func.linearize(f, x0)
J = func.jacfwd(f)(x0)
trace = linalg.trace(J)
sample_fun = montecarlo.normal(shape=(n,), dtype=float)
sample_fun = hutchinson.normal(shape=(n,), dtype=float)

return (jvp, trace, J), (key, sample_fun)

4 changes: 2 additions & 2 deletions docs/benchmarks/jacobian_squared.py
@@ -1,5 +1,5 @@
"""What is the fastest way of computing trace(A^5)."""
from matfree import benchmark_util, hutchinson, montecarlo, slq
from matfree import benchmark_util, hutchinson, slq
from matfree.backend import func, linalg, np, plt, prng
from matfree.backend.progressbar import progressbar

@@ -20,7 +20,7 @@ def f(x):
J = func.jacfwd(f)(x0)
A = J @ J @ J @ J
trace = linalg.trace(A)
sample_fun = montecarlo.normal(shape=(n,), dtype=float)
sample_fun = hutchinson.normal(shape=(n,), dtype=float)

def Av(v):
return jvp(jvp(jvp(jvp(v))))
4 changes: 2 additions & 2 deletions docs/control_variates.md
@@ -8,13 +8,13 @@ Imports:
```python
>>> import jax
>>> import jax.numpy as jnp
>>> from matfree import hutchinson, montecarlo, slq
>>> from matfree import hutchinson, slq

>>> a = jnp.reshape(jnp.arange(12.0), (6, 2))
>>> key = jax.random.PRNGKey(1)

>>> matvec = lambda x: a.T @ (a @ x)
>>> sample_fun = montecarlo.normal(shape=(2,))
>>> sample_fun = hutchinson.normal(shape=(2,))

```

8 changes: 4 additions & 4 deletions docs/higher_moments.md
@@ -3,13 +3,13 @@
```python
>>> import jax
>>> import jax.numpy as jnp
>>> from matfree import hutchinson, montecarlo, slq
>>> from matfree import hutchinson, slq

>>> a = jnp.reshape(jnp.arange(12.0), (6, 2))
>>> key = jax.random.PRNGKey(1)

>>> mvp = lambda x: a.T @ (a @ x)
>>> sample_fun = montecarlo.normal(shape=(2,))
>>> sample_fun = hutchinson.normal(shape=(2,))

```

@@ -21,7 +21,7 @@ Compute them as such

```python
>>> a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36
>>> normal = montecarlo.normal(shape=(6,))
>>> normal = hutchinson.normal(shape=(6,))
>>> mvp = lambda x: a.T @ (a @ x) + x
>>> first, second = hutchinson.trace_moments(mvp, key=key, sample_fun=normal)
>>> print(jnp.round(first, 1))
@@ -53,7 +53,7 @@ Implement this as follows:

```python
>>> a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36
>>> sample_fun = montecarlo.normal(shape=(6,))
>>> sample_fun = hutchinson.normal(shape=(6,))
>>> num_samples = 10_000
>>> mvp = lambda x: a.T @ (a @ x) + x
>>> first, second = hutchinson.trace_moments(
4 changes: 2 additions & 2 deletions docs/index.md
@@ -51,7 +51,7 @@ Import matfree and JAX, and set up a test problem.
```python
>>> import jax
>>> import jax.numpy as jnp
>>> from matfree import hutchinson, montecarlo, slq
>>> from matfree import hutchinson, slq

>>> A = jnp.reshape(jnp.arange(12.0), (6, 2))
>>>
@@ -65,7 +65,7 @@ Estimate the trace of the matrix:

```python
>>> key = jax.random.PRNGKey(1)
>>> normal = montecarlo.normal(shape=(2,))
>>> normal = hutchinson.normal(shape=(2,))
>>> trace = hutchinson.trace(matvec, key=key, sample_fun=normal)
>>>
>>> print(jnp.round(trace))
8 changes: 4 additions & 4 deletions docs/log_determinants.md
@@ -6,21 +6,21 @@ Imports:
```python
>>> import jax
>>> import jax.numpy as jnp
>>> from matfree import hutchinson, montecarlo, slq
>>> from matfree import hutchinson, slq

>>> a = jnp.reshape(jnp.arange(12.0), (6, 2))
>>> key = jax.random.PRNGKey(1)

>>> matvec = lambda x: a.T @ (a @ x)
>>> sample_fun = montecarlo.normal(shape=(2,))
>>> sample_fun = hutchinson.normal(shape=(2,))

```


Estimate log-determinants as such:
```python
>>> a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36
>>> sample_fun = montecarlo.normal(shape=(6,))
>>> sample_fun = hutchinson.normal(shape=(6,))
>>> matvec = lambda x: a.T @ (a @ x) + x
>>> order = 3
>>> logdet = slq.logdet_spd(order, matvec, key=key, sample_fun=sample_fun)
@@ -37,7 +37,7 @@ on arithmetic with $B$; no need to assemble $M$:

```python
>>> a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36 + jnp.eye(6)
>>> sample_fun = montecarlo.normal(shape=(6,))
>>> sample_fun = hutchinson.normal(shape=(6,))
>>> matvec = lambda x: (a @ x)
>>> vecmat = lambda x: (a.T @ x)
>>> order = 3
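For completeness, here is a sketch of how the updated import composes for the stochastic Lanczos quadrature log-determinant shown above. All values are taken from the snippet in this file; the result is left unprinted because the expected output is not part of this hunk.

```python
import jax
import jax.numpy as jnp
from matfree import hutchinson, slq

a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36
key = jax.random.PRNGKey(1)

# Symmetric positive definite operator from the docs example.
matvec = lambda x: a.T @ (a @ x) + x
sample_fun = hutchinson.normal(shape=(6,))  # previously montecarlo.normal

order = 3
logdet = slq.logdet_spd(order, matvec, key=key, sample_fun=sample_fun)
```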
4 changes: 2 additions & 2 deletions docs/pytree_logdeterminants.md
@@ -11,7 +11,7 @@ Imports:
>>> import jax.flatten_util # this is important!
>>> import jax.numpy as jnp
>>>
>>> from matfree import slq, montecarlo
>>> from matfree import slq, hutchinson

```
Create a test-problem: a function that maps a pytree (dict) to a pytree (tuple).
@@ -70,7 +70,7 @@ Now, we can compute the log-determinant with the flattened inputs as usual:
```python
>>> # Compute the log-determinant
>>> key = jax.random.PRNGKey(seed=1)
>>> sample_fun = montecarlo.normal(shape=f0_flat.shape)
>>> sample_fun = hutchinson.normal(shape=f0_flat.shape)
>>> order = 3
>>> logdet = slq.logdet_spd(order, matvec, key=key, sample_fun=sample_fun)

4 changes: 2 additions & 2 deletions docs/vector_calculus.md
@@ -11,7 +11,7 @@ Here is how we can implement divergences and Laplacians without forming full Jacobians
```python
>>> import jax
>>> import jax.numpy as jnp
>>> from matfree import hutchinson, montecarlo
>>> from matfree import hutchinson

```

@@ -85,7 +85,7 @@ For large-scale problems, it may be the only way of computing Laplacians reliably
```python
>>> laplacian_dense = divergence_dense(gradient)
>>>
>>> normal = montecarlo.normal(shape=(3,))
>>> normal = hutchinson.normal(shape=(3,))
>>> key = jax.random.PRNGKey(1)
>>> laplacian_matfree = divergence_matfree(gradient, key=key, sample_fun=normal)
>>>
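The body of `divergence_matfree` is elided in this hunk. As an illustration only (not necessarily how the docs implement it), a matrix-free Laplacian at a point can be phrased as a Hutchinson trace estimate of the Hessian-vector product; the scalar field `f` and the point `x0` below are made-up examples.

```python
import jax
import jax.numpy as jnp
from matfree import hutchinson

f = lambda x: jnp.sum(x**2)   # illustrative scalar field
gradient = jax.grad(f)
x0 = jnp.ones((3,))           # illustrative evaluation point

def hessian_vector_product(v):
    # Jacobian of the gradient (i.e. the Hessian) applied to v.
    return jax.jvp(gradient, (x0,), (v,))[1]

key = jax.random.PRNGKey(1)
normal = hutchinson.normal(shape=(3,))
laplacian_estimate = hutchinson.trace(
    hessian_vector_product, key=key, sample_fun=normal
)
```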
164 changes: 160 additions & 4 deletions matfree/decomp.py
@@ -1,10 +1,166 @@
"""Matrix decomposition algorithms."""

from matfree import lanczos
from matfree.backend import containers, control_flow, linalg
from matfree.backend import containers, control_flow, linalg, np
from matfree.backend.typing import Array, Callable, Tuple


class _Alg(containers.NamedTuple):
"""Matrix decomposition algorithm."""

init: Callable
"""Initialise the state of the algorithm. Usually, this involves pre-allocation."""

step: Callable
"""Compute the next iteration."""

extract: Callable
"""Extract the solution from the state of the algorithm."""

lower_upper: Tuple[int, int]
"""Range of the for-loop used to decompose a matrix."""


def tridiagonal_full_reortho(depth, /):
"""Construct an implementation of **tridiagonalisation**.

Uses pre-allocation. Fully reorthogonalise vectors at every step.

This algorithm assumes a **symmetric matrix**.

Decompose a matrix into a product of orthogonal-**tridiagonal**-orthogonal matrices.
Use this algorithm for approximate **eigenvalue** decompositions.

"""

class State(containers.NamedTuple):
i: int
basis: Array
tridiag: Tuple[Array, Array]
q: Array

def init(init_vec: Array) -> State:
(ncols,) = np.shape(init_vec)
if depth >= ncols or depth < 1:
raise ValueError

diag = np.zeros((depth + 1,))
offdiag = np.zeros((depth,))
basis = np.zeros((depth + 1, ncols))

return State(0, basis, (diag, offdiag), init_vec)

def apply(state: State, Av: Callable) -> State:
i, basis, (diag, offdiag), vec = state

# Re-orthogonalise against ALL basis elements before storing.
# Note: we re-orthogonalise against ALL columns of Q, not just
# the ones we have already computed. This increases the complexity
# of the whole iteration from n(n+1)/2 to n^2, but has the advantage
# that the whole computation has static bounds (thus we can JIT it all).
# Since 'Q' is padded with zeros, the numerical values are identical
# between both modes of computing.
vec, length = _normalise(vec)
vec, _ = _gram_schmidt_orthogonalise_set(vec, basis)

        # Re-normalise once more; empirically, this extra step is crucial for stability.
vec, _ = _normalise(vec)
basis = basis.at[i, :].set(vec)

# When i==0, Q[i-1] is Q[-1] and again, we benefit from the fact
# that Q is initialised with zeros.
vec = Av(vec)
basis_vectors_previous = np.asarray([basis[i], basis[i - 1]])
vec, (coeff, _) = _gram_schmidt_orthogonalise_set(vec, basis_vectors_previous)
diag = diag.at[i].set(coeff)
offdiag = offdiag.at[i - 1].set(length)

return State(i + 1, basis, (diag, offdiag), vec)

def extract(state: State, /):
_, basis, (diag, offdiag), _ = state
return basis, (diag, offdiag)

return _Alg(init=init, step=apply, extract=extract, lower_upper=(0, depth + 1))


def bidiagonal_full_reortho(depth, /, matrix_shape):
"""Construct an implementation of **bidiagonalisation**.

Uses pre-allocation. Fully reorthogonalise vectors at every step.

Works for **arbitrary matrices**. No symmetry required.

Decompose a matrix into a product of orthogonal-**bidiagonal**-orthogonal matrices.
Use this algorithm for approximate **singular value** decompositions.
"""
nrows, ncols = matrix_shape
max_depth = min(nrows, ncols) - 1
if depth > max_depth or depth < 0:
msg1 = f"Depth {depth} exceeds the matrix' dimensions. "
msg2 = f"Expected: 0 <= depth <= min(nrows, ncols) - 1 = {max_depth} "
msg3 = f"for a matrix with shape {matrix_shape}."
raise ValueError(msg1 + msg2 + msg3)

class State(containers.NamedTuple):
i: int
Us: Array
Vs: Array
alphas: Array
betas: Array
beta: Array
vk: Array

def init(init_vec: Array) -> State:
nrows, ncols = matrix_shape
alphas = np.zeros((depth + 1,))
betas = np.zeros((depth + 1,))
Us = np.zeros((depth + 1, nrows))
Vs = np.zeros((depth + 1, ncols))
v0, _ = _normalise(init_vec)
return State(0, Us, Vs, alphas, betas, 0.0, v0)

def apply(state: State, Av: Callable, vA: Callable) -> State:
i, Us, Vs, alphas, betas, beta, vk = state
Vs = Vs.at[i].set(vk)
betas = betas.at[i].set(beta)

uk = Av(vk) - beta * Us[i - 1]
uk, alpha = _normalise(uk)
uk, _ = _gram_schmidt_orthogonalise_set(uk, Us) # full reorthogonalisation
uk, _ = _normalise(uk)
Us = Us.at[i].set(uk)
alphas = alphas.at[i].set(alpha)

vk = vA(uk) - alpha * vk
vk, beta = _normalise(vk)
vk, _ = _gram_schmidt_orthogonalise_set(vk, Vs) # full reorthogonalisation
vk, _ = _normalise(vk)

return State(i + 1, Us, Vs, alphas, betas, beta, vk)

def extract(state: State, /):
_, uk_all, vk_all, alphas, betas, beta, vk = state
return uk_all.T, (alphas, betas[1:]), vk_all, (beta, vk)

return _Alg(init=init, step=apply, extract=extract, lower_upper=(0, depth + 1))


def _normalise(vec):
length = linalg.vector_norm(vec)
return vec / length, length


def _gram_schmidt_orthogonalise_set(vec, vectors): # Gram-Schmidt
vec, coeffs = control_flow.scan(_gram_schmidt_orthogonalise, vec, xs=vectors)
return vec, coeffs


def _gram_schmidt_orthogonalise(vec1, vec2):
coeff = linalg.vecdot(vec1, vec2)
vec_ortho = vec1 - coeff * vec2
return vec_ortho, coeff


def svd(
v0: Array, depth: int, Av: Callable, vA: Callable, matrix_shape: Tuple[int, ...]
):
@@ -29,7 +185,7 @@ def svd(
Shape of the matrix involved in matrix-vector and vector-matrix products.
"""
# Factorise the matrix
algorithm = lanczos.bidiagonal_full_reortho(depth, matrix_shape=matrix_shape)
algorithm = bidiagonal_full_reortho(depth, matrix_shape=matrix_shape)
u, (d, e), vt, _ = decompose_fori_loop(v0, Av, vA, algorithm=algorithm)

# Compute SVD of factorisation
@@ -66,7 +222,7 @@ class _DecompAlg(containers.NamedTuple):
"""Decomposition algorithm type.

For example, the output of
[matfree.lanczos.tridiagonal_full_reortho(...)][matfree.lanczos.tridiagonal_full_reortho].
[matfree.decomp.tridiagonal_full_reortho(...)][matfree.decomp.tridiagonal_full_reortho].
"""


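Since `decomp` now owns these factorisations directly, here is a hedged sketch of calling `decomp.svd` against the signature shown above. The matrix, starting vector, and depth are illustrative, and the return value is left unpacked because its exact structure is not visible in this diff.

```python
import jax.numpy as jnp
from matfree import decomp

A = jnp.reshape(jnp.arange(12.0), (6, 2))   # illustrative rectangular matrix

Av = lambda v: A @ v    # matrix-vector product
vA = lambda v: v @ A    # vector-matrix product

v0 = jnp.ones((2,))     # starting vector with shape (ncols,)
depth = 1               # must satisfy 0 <= depth <= min(nrows, ncols) - 1

# Exact contents of the result depend on matfree's API; left unpacked here.
result = decomp.svd(v0, depth, Av, vA, A.shape)
```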