Merge montecarlo/hutchinson and decomp/lanczos modules (#161)
* Merged montecarlo and hutchinson files

* All tests pass again

* Do not show source in docs

* Merged decompositions and Lanczos
pnkraemer authored Nov 21, 2023
1 parent 08cabcb commit 2210962
Showing 39 changed files with 424 additions and 447 deletions.
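The practical effect of the merge is a pair of import renames. Here is a minimal before/after sketch, reconstructed from the removed lines in the hunks below (the standalone `montecarlo` and `lanczos` modules disappear):

```python
# Before this commit (reconstructed from the deletions below):
#   from matfree import lanczos, montecarlo
#   sample_fun = montecarlo.normal(shape=(2,))
#   algorithm = lanczos.bidiagonal_full_reortho(1, matrix_shape=(6, 2))

# After this commit:
from matfree import decomp, hutchinson

sample_fun = hutchinson.normal(shape=(2,))
algorithm = decomp.bidiagonal_full_reortho(1, matrix_shape=(6, 2))
```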
1 change: 1 addition & 0 deletions Makefile
@@ -20,6 +20,7 @@ clean:
rm -rf *.egg-info
rm -rf dist site build htmlcov
rm -rf *.ipynb_checkpoints
+rm matfree/_version.py

doc:
mkdocs build
4 changes: 2 additions & 2 deletions README.md
@@ -51,7 +51,7 @@ Import matfree and JAX, and set up a test problem.
```python
>>> import jax
>>> import jax.numpy as jnp
->>> from matfree import hutchinson, montecarlo, slq
+>>> from matfree import hutchinson, slq

>>> A = jnp.reshape(jnp.arange(12.0), (6, 2))
>>>
@@ -65,7 +65,7 @@ Estimate the trace of the matrix:

```python
>>> key = jax.random.PRNGKey(1)
->>> normal = montecarlo.normal(shape=(2,))
+>>> normal = hutchinson.normal(shape=(2,))
>>> trace = hutchinson.trace(matvec, key=key, sample_fun=normal)
>>>
>>> print(jnp.round(trace))
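Assembled from the two README hunks above, the updated example runs end to end as follows. The definition of `matvec` sits in lines elided between the hunks, so the one used here (the normal-equations product that appears on the other docs pages) is an assumption:

```python
import jax
import jax.numpy as jnp
from matfree import hutchinson

A = jnp.reshape(jnp.arange(12.0), (6, 2))
matvec = lambda x: A.T @ (A @ x)  # assumed: defined in lines elided between the hunks

key = jax.random.PRNGKey(1)
normal = hutchinson.normal(shape=(2,))
trace = hutchinson.trace(matvec, key=key, sample_fun=normal)
print(jnp.round(trace))
```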
3 changes: 0 additions & 3 deletions docs/api/lanczos.md

This file was deleted.

2 changes: 0 additions & 2 deletions docs/api/montecarlo.md

This file was deleted.

4 changes: 2 additions & 2 deletions docs/benchmarks/control_variates.py
@@ -3,7 +3,7 @@
Runtime: ~10 seconds.
"""

-from matfree import benchmark_util, hutchinson, montecarlo
+from matfree import benchmark_util, hutchinson
from matfree.backend import func, linalg, np, plt, prng, progressbar


@@ -22,7 +22,7 @@ def f(x):
_, jvp = func.linearize(f, x0)
J = func.jacfwd(f)(x0)
trace = linalg.trace(J)
-sample_fun = montecarlo.normal(shape=(n,), dtype=float)
+sample_fun = hutchinson.normal(shape=(n,), dtype=float)

return (jvp, trace, J), (key, sample_fun)

4 changes: 2 additions & 2 deletions docs/benchmarks/jacobian_squared.py
@@ -1,5 +1,5 @@
"""What is the fastest way of computing trace(A^5)."""
-from matfree import benchmark_util, hutchinson, montecarlo, slq
+from matfree import benchmark_util, hutchinson, slq
from matfree.backend import func, linalg, np, plt, prng
from matfree.backend.progressbar import progressbar

@@ -20,7 +20,7 @@ def f(x):
J = func.jacfwd(f)(x0)
A = J @ J @ J @ J
trace = linalg.trace(A)
-sample_fun = montecarlo.normal(shape=(n,), dtype=float)
+sample_fun = hutchinson.normal(shape=(n,), dtype=float)

def Av(v):
return jvp(jvp(jvp(jvp(v))))
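The benchmark's `Av` relies on composing a linearised Jacobian-vector product with itself. A self-contained sketch of that trick (the test function `f` and point `x0` are made up for illustration):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x  # illustrative test function

x0 = jnp.arange(1.0, 4.0)
_, jvp = jax.linearize(f, x0)  # jvp(v) == J @ v, with J the Jacobian at x0

def Av(v):
    return jvp(jvp(jvp(jvp(v))))  # v -> J^4 @ v, without ever forming J^4

J = jax.jacfwd(f)(x0)
print(jnp.allclose(Av(jnp.ones(3)), jnp.linalg.matrix_power(J, 4) @ jnp.ones(3)))
```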
4 changes: 2 additions & 2 deletions docs/control_variates.md
@@ -8,13 +8,13 @@ Imports:
```python
>>> import jax
>>> import jax.numpy as jnp
->>> from matfree import hutchinson, montecarlo, slq
+>>> from matfree import hutchinson, slq

>>> a = jnp.reshape(jnp.arange(12.0), (6, 2))
>>> key = jax.random.PRNGKey(1)

>>> matvec = lambda x: a.T @ (a @ x)
->>> sample_fun = montecarlo.normal(shape=(2,))
+>>> sample_fun = hutchinson.normal(shape=(2,))

```

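For reference, the identity this page builds on: for any matrix $B$ with a cheaply known trace, $\operatorname{tr}(A) = \operatorname{tr}(A - B) + \operatorname{tr}(B)$, and estimating $\operatorname{tr}(A - B)$ stochastically has lower variance when $B \approx A$. A dense sanity check of the identity (plain linear algebra, not matfree API):

```python
import jax.numpy as jnp

a = jnp.reshape(jnp.arange(12.0), (6, 2))
A = a.T @ a
B = jnp.diag(jnp.diag(A))  # control variate with a known trace
# In practice, trace(A - B) is the part that gets estimated stochastically.
print(jnp.trace(A - B) + jnp.trace(B), jnp.trace(A))  # identical, by linearity
```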
8 changes: 4 additions & 4 deletions docs/higher_moments.md
@@ -3,13 +3,13 @@
```python
>>> import jax
>>> import jax.numpy as jnp
->>> from matfree import hutchinson, montecarlo, slq
+>>> from matfree import hutchinson, slq

>>> a = jnp.reshape(jnp.arange(12.0), (6, 2))
>>> key = jax.random.PRNGKey(1)

>>> mvp = lambda x: a.T @ (a @ x)
->>> sample_fun = montecarlo.normal(shape=(2,))
+>>> sample_fun = hutchinson.normal(shape=(2,))

```

@@ -21,7 +21,7 @@ Compute them as such

```python
>>> a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36
->>> normal = montecarlo.normal(shape=(6,))
+>>> normal = hutchinson.normal(shape=(6,))
>>> mvp = lambda x: a.T @ (a @ x) + x
>>> first, second = hutchinson.trace_moments(mvp, key=key, sample_fun=normal)
>>> print(jnp.round(first, 1))
@@ -53,7 +53,7 @@ Implement this as follows:

```python
>>> a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36
->>> sample_fun = montecarlo.normal(shape=(6,))
+>>> sample_fun = hutchinson.normal(shape=(6,))
>>> num_samples = 10_000
>>> mvp = lambda x: a.T @ (a @ x) + x
>>> first, second = hutchinson.trace_moments(
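Assuming `trace_moments` returns the first two raw moments of the per-sample estimates (consistent with its use above; the exact return convention is not shown in this diff), the estimator's variance follows from the usual identity $\mathrm{Var} = E[X^2] - E[X]^2$:

```python
import jax
import jax.numpy as jnp
from matfree import hutchinson

key = jax.random.PRNGKey(1)
a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36
mvp = lambda x: a.T @ (a @ x) + x
normal = hutchinson.normal(shape=(6,))

first, second = hutchinson.trace_moments(mvp, key=key, sample_fun=normal)
variance = second - first**2  # assumed: raw moments, so this is the sample variance
```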
4 changes: 2 additions & 2 deletions docs/index.md
@@ -51,7 +51,7 @@ Import matfree and JAX, and set up a test problem.
```python
>>> import jax
>>> import jax.numpy as jnp
->>> from matfree import hutchinson, montecarlo, slq
+>>> from matfree import hutchinson, slq

>>> A = jnp.reshape(jnp.arange(12.0), (6, 2))
>>>
@@ -65,7 +65,7 @@ Estimate the trace of the matrix:

```python
>>> key = jax.random.PRNGKey(1)
->>> normal = montecarlo.normal(shape=(2,))
+>>> normal = hutchinson.normal(shape=(2,))
>>> trace = hutchinson.trace(matvec, key=key, sample_fun=normal)
>>>
>>> print(jnp.round(trace))
8 changes: 4 additions & 4 deletions docs/log_determinants.md
@@ -6,21 +6,21 @@ Imports:
```python
>>> import jax
>>> import jax.numpy as jnp
->>> from matfree import hutchinson, montecarlo, slq
+>>> from matfree import hutchinson, slq

>>> a = jnp.reshape(jnp.arange(12.0), (6, 2))
>>> key = jax.random.PRNGKey(1)

>>> matvec = lambda x: a.T @ (a @ x)
->>> sample_fun = montecarlo.normal(shape=(2,))
+>>> sample_fun = hutchinson.normal(shape=(2,))

```


Estimate log-determinants as such:
```python
>>> a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36
->>> sample_fun = montecarlo.normal(shape=(6,))
+>>> sample_fun = hutchinson.normal(shape=(6,))
>>> matvec = lambda x: a.T @ (a @ x) + x
>>> order = 3
>>> logdet = slq.logdet_spd(order, matvec, key=key, sample_fun=sample_fun)
@@ -37,7 +37,7 @@ on arithmetic with $B$; no need to assemble $M$:

```python
>>> a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36 + jnp.eye(6)
->>> sample_fun = montecarlo.normal(shape=(6,))
+>>> sample_fun = hutchinson.normal(shape=(6,))
>>> matvec = lambda x: (a @ x)
>>> vecmat = lambda x: (a.T @ x)
>>> order = 3
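Since the test matrix is small, the stochastic Lanczos quadrature estimate above can be sanity-checked against a dense log-determinant. A sketch (how closely the two agree depends on `order` and the number of samples):

```python
import jax
import jax.numpy as jnp
from matfree import hutchinson, slq

key = jax.random.PRNGKey(1)
a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36
matvec = lambda x: a.T @ (a @ x) + x  # SPD by construction
sample_fun = hutchinson.normal(shape=(6,))

logdet_slq = slq.logdet_spd(3, matvec, key=key, sample_fun=sample_fun)
_, logdet_dense = jnp.linalg.slogdet(a.T @ a + jnp.eye(6))
print(logdet_slq, logdet_dense)  # the two should be close
```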
4 changes: 2 additions & 2 deletions docs/pytree_logdeterminants.md
@@ -11,7 +11,7 @@ Imports:
>>> import jax.flatten_util # this is important!
>>> import jax.numpy as jnp
>>>
->>> from matfree import slq, montecarlo
+>>> from matfree import slq, hutchinson

```
Create a test-problem: a function that maps a pytree (dict) to a pytree (tuple).
@@ -70,7 +70,7 @@ Now, we can compute the log-determinant with the flattened inputs as usual:
```python
>>> # Compute the log-determinant
>>> key = jax.random.PRNGKey(seed=1)
->>> sample_fun = montecarlo.normal(shape=f0_flat.shape)
+>>> sample_fun = hutchinson.normal(shape=f0_flat.shape)
>>> order = 3
>>> logdet = slq.logdet_spd(order, matvec, key=key, sample_fun=sample_fun)

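The flattening step that the elided hunks rely on is `jax.flatten_util.ravel_pytree`, which returns the flat vector together with an unflatten function. A minimal sketch with a made-up pytree:

```python
import jax
import jax.flatten_util
import jax.numpy as jnp

pytree = {"weights": jnp.ones((2, 2)), "bias": jnp.zeros((2,))}
flat, unflatten = jax.flatten_util.ravel_pytree(pytree)

print(flat.shape)                        # (6,): the shape handed to hutchinson.normal
print(unflatten(flat)["weights"].shape)  # (2, 2): the structure round-trips
```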
4 changes: 2 additions & 2 deletions docs/vector_calculus.md
@@ -11,7 +11,7 @@ Here is how we can implement divergences and Laplacians without forming full Jac
```python
>>> import jax
>>> import jax.numpy as jnp
->>> from matfree import hutchinson, montecarlo
+>>> from matfree import hutchinson

```

@@ -85,7 +85,7 @@ For large-scale problems, it may be the only way of computing Laplacians reliabl
```python
>>> laplacian_dense = divergence_dense(gradient)
>>>
->>> normal = montecarlo.normal(shape=(3,))
+>>> normal = hutchinson.normal(shape=(3,))
>>> key = jax.random.PRNGKey(1)
>>> laplacian_matfree = divergence_matfree(gradient, key=key, sample_fun=normal)
>>>
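The page's `divergence_matfree` helper is elided here; the underlying idea is that the Laplacian of a scalar function is the trace of its Hessian, which Hutchinson can estimate from Hessian-vector products alone. A sketch under that assumption, with a made-up test function `g`:

```python
import jax
import jax.numpy as jnp
from matfree import hutchinson

def g(x):
    return jnp.sum(x**4)  # illustrative scalar field

gradient = jax.grad(g)
x0 = jnp.array([1.0, 2.0, 3.0])

# Dense reference: the Laplacian is the trace of the Hessian.
laplacian_dense = jnp.trace(jax.jacfwd(gradient)(x0))

# Matrix-free: Hutchinson applied to the Hessian-vector product.
hvp = lambda v: jax.jvp(gradient, (x0,), (v,))[1]
key = jax.random.PRNGKey(1)
normal = hutchinson.normal(shape=(3,))
laplacian_matfree = hutchinson.trace(hvp, key=key, sample_fun=normal)
```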
164 changes: 160 additions & 4 deletions matfree/decomp.py
@@ -1,10 +1,166 @@
"""Matrix decomposition algorithms."""

-from matfree import lanczos
-from matfree.backend import containers, control_flow, linalg
+from matfree.backend import containers, control_flow, linalg, np
from matfree.backend.typing import Array, Callable, Tuple


class _Alg(containers.NamedTuple):
"""Matrix decomposition algorithm."""

init: Callable
"""Initialise the state of the algorithm. Usually, this involves pre-allocation."""

step: Callable
"""Compute the next iteration."""

extract: Callable
"""Extract the solution from the state of the algorithm."""

lower_upper: Tuple[int, int]
"""Range of the for-loop used to decompose a matrix."""


def tridiagonal_full_reortho(depth, /):
"""Construct an implementation of **tridiagonalisation**.
Uses pre-allocation. Fully reorthogonalise vectors at every step.
This algorithm assumes a **symmetric matrix**.
Decompose a matrix into a product of orthogonal-**tridiagonal**-orthogonal matrices.
Use this algorithm for approximate **eigenvalue** decompositions.
"""

class State(containers.NamedTuple):
i: int
basis: Array
tridiag: Tuple[Array, Array]
q: Array

def init(init_vec: Array) -> State:
(ncols,) = np.shape(init_vec)
if depth >= ncols or depth < 1:
raise ValueError

diag = np.zeros((depth + 1,))
offdiag = np.zeros((depth,))
basis = np.zeros((depth + 1, ncols))

return State(0, basis, (diag, offdiag), init_vec)

def apply(state: State, Av: Callable) -> State:
i, basis, (diag, offdiag), vec = state

# Re-orthogonalise against ALL basis elements before storing.
# Note: we re-orthogonalise against ALL columns of Q, not just
# the ones we have already computed. This increases the complexity
# of the whole iteration from n(n+1)/2 to n^2, but has the advantage
# that the whole computation has static bounds (thus we can JIT it all).
# Since 'Q' is padded with zeros, the numerical values are identical
# between both modes of computing.
vec, length = _normalise(vec)
vec, _ = _gram_schmidt_orthogonalise_set(vec, basis)

        # I don't know why, but this re-normalisation is crucial
vec, _ = _normalise(vec)
basis = basis.at[i, :].set(vec)

# When i==0, Q[i-1] is Q[-1] and again, we benefit from the fact
# that Q is initialised with zeros.
vec = Av(vec)
basis_vectors_previous = np.asarray([basis[i], basis[i - 1]])
vec, (coeff, _) = _gram_schmidt_orthogonalise_set(vec, basis_vectors_previous)
diag = diag.at[i].set(coeff)
offdiag = offdiag.at[i - 1].set(length)

return State(i + 1, basis, (diag, offdiag), vec)

def extract(state: State, /):
_, basis, (diag, offdiag), _ = state
return basis, (diag, offdiag)

return _Alg(init=init, step=apply, extract=extract, lower_upper=(0, depth + 1))


def bidiagonal_full_reortho(depth, /, matrix_shape):
"""Construct an implementation of **bidiagonalisation**.
Uses pre-allocation. Fully reorthogonalise vectors at every step.
Works for **arbitrary matrices**. No symmetry required.
Decompose a matrix into a product of orthogonal-**bidiagonal**-orthogonal matrices.
Use this algorithm for approximate **singular value** decompositions.
"""
nrows, ncols = matrix_shape
max_depth = min(nrows, ncols) - 1
if depth > max_depth or depth < 0:
        msg1 = f"Depth {depth} exceeds the matrix's dimensions. "
msg2 = f"Expected: 0 <= depth <= min(nrows, ncols) - 1 = {max_depth} "
msg3 = f"for a matrix with shape {matrix_shape}."
raise ValueError(msg1 + msg2 + msg3)

class State(containers.NamedTuple):
i: int
Us: Array
Vs: Array
alphas: Array
betas: Array
beta: Array
vk: Array

def init(init_vec: Array) -> State:
nrows, ncols = matrix_shape
alphas = np.zeros((depth + 1,))
betas = np.zeros((depth + 1,))
Us = np.zeros((depth + 1, nrows))
Vs = np.zeros((depth + 1, ncols))
v0, _ = _normalise(init_vec)
return State(0, Us, Vs, alphas, betas, 0.0, v0)

def apply(state: State, Av: Callable, vA: Callable) -> State:
i, Us, Vs, alphas, betas, beta, vk = state
Vs = Vs.at[i].set(vk)
betas = betas.at[i].set(beta)

uk = Av(vk) - beta * Us[i - 1]
uk, alpha = _normalise(uk)
uk, _ = _gram_schmidt_orthogonalise_set(uk, Us) # full reorthogonalisation
uk, _ = _normalise(uk)
Us = Us.at[i].set(uk)
alphas = alphas.at[i].set(alpha)

vk = vA(uk) - alpha * vk
vk, beta = _normalise(vk)
vk, _ = _gram_schmidt_orthogonalise_set(vk, Vs) # full reorthogonalisation
vk, _ = _normalise(vk)

return State(i + 1, Us, Vs, alphas, betas, beta, vk)

def extract(state: State, /):
_, uk_all, vk_all, alphas, betas, beta, vk = state
return uk_all.T, (alphas, betas[1:]), vk_all, (beta, vk)

return _Alg(init=init, step=apply, extract=extract, lower_upper=(0, depth + 1))


def _normalise(vec):
length = linalg.vector_norm(vec)
return vec / length, length


def _gram_schmidt_orthogonalise_set(vec, vectors): # Gram-Schmidt
vec, coeffs = control_flow.scan(_gram_schmidt_orthogonalise, vec, xs=vectors)
return vec, coeffs


def _gram_schmidt_orthogonalise(vec1, vec2):
coeff = linalg.vecdot(vec1, vec2)
vec_ortho = vec1 - coeff * vec2
return vec_ortho, coeff


def svd(
v0: Array, depth: int, Av: Callable, vA: Callable, matrix_shape: Tuple[int, ...]
):
@@ -29,7 +29,7 @@ def svd(
Shape of the matrix involved in matrix-vector and vector-matrix products.
"""
# Factorise the matrix
-    algorithm = lanczos.bidiagonal_full_reortho(depth, matrix_shape=matrix_shape)
+    algorithm = bidiagonal_full_reortho(depth, matrix_shape=matrix_shape)
u, (d, e), vt, _ = decompose_fori_loop(v0, Av, vA, algorithm=algorithm)

# Compute SVD of factorisation
@@ -66,7 +66,7 @@ class _DecompAlg(containers.NamedTuple):
"""Decomposition algorithm type.
For example, the output of
-    [matfree.lanczos.tridiagonal_full_reortho(...)][matfree.lanczos.tridiagonal_full_reortho].
+    [matfree.decomp.tridiagonal_full_reortho(...)][matfree.decomp.tridiagonal_full_reortho].
"""


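A hypothetical usage sketch for the newly merged tridiagonalisation: the `extract` convention (`basis, (diag, offdiag)`) is taken from the code above, while the single-matvec call signature of `decompose_fori_loop` is an assumption extrapolated from its two-matvec use inside `svd`:

```python
import jax
import jax.numpy as jnp
from matfree import decomp

a = jnp.reshape(jnp.arange(36.0), (6, 6)) / 36
A = a.T @ a + jnp.eye(6)  # symmetric, as the algorithm requires

v0 = jax.random.normal(jax.random.PRNGKey(1), shape=(6,))
algorithm = decomp.tridiagonal_full_reortho(3)  # valid depths: 1 <= depth < ncols

# Assumed call convention, mirroring decompose_fori_loop(v0, Av, vA, ...) in svd():
basis, (diag, offdiag) = decomp.decompose_fori_loop(
    v0, lambda x: A @ x, algorithm=algorithm
)
# basis: (depth + 1, 6) Lanczos vectors; diag/offdiag: the tridiagonal entries.
```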