Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement the Chebyshev recursion for computing matrix-function-vector products (draft) #175

Merged
merged 12 commits into from
Jan 15, 2024
Merged
8 changes: 8 additions & 0 deletions matfree/backend/np.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def arange(start, /, stop=None, step=1):
return jnp.arange(start, stop, step)


def linspace(start, stop, /, *, num, endpoint=True):
return jnp.linspace(start, stop, num=num, endpoint=endpoint)


def asarray(obj, /):
return jnp.asarray(obj)

Expand Down Expand Up @@ -170,6 +174,10 @@ def nan():
return jnp.nan


def pi():
return jnp.pi


def finfo_eps(x, /):
return jnp.finfo(x).eps

Expand Down
3 changes: 3 additions & 0 deletions matfree/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ def decompose(v0, *matvec_funs, algorithm):
init, step, extract, (lower, upper) = algorithm
init_val = init(v0)

# todo: parametrized matrix-vector products
# todo: move matvec_funs into the algorithm,
# and parameters into the decompose
def body_fun(_, s):
return step(s, *matvec_funs)

Expand Down
136 changes: 136 additions & 0 deletions matfree/matfun.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
"""Implement approximations of matrix-function-vector products.

This module is experimental.

Examples
--------
>>> import jax.random
>>> import jax.numpy as jnp
>>>
>>> jnp.set_printoptions(1)
>>>
>>> M = jax.random.normal(jax.random.PRNGKey(1), shape=(10, 10))
>>> A = M.T @ M
>>> v = jax.random.normal(jax.random.PRNGKey(2), shape=(10,))
>>>
>>> # Compute a matrix-logarithm with Lanczos' algorithm
>>> matfun_vec = lanczos_spd(jnp.log, 4, lambda s: A @ s)
>>> matfun_vec(v)
Array([-4. , -2.1, -2.7, -1.9, -1.3, -3.5, -0.5, -0.1, 0.3, 1.5], dtype=float32)
"""

from matfree import decomp
from matfree.backend import containers, control_flow, func, linalg, np
from matfree.backend.typing import Array


def lanczos_spd(matfun, order, matvec, /):
"""Implement a matrix-function-vector product via Lanczos' algorithm.

This algorithm uses Lanczos' tridiagonalisation with full re-orthogonalisation.
"""
# Lanczos' algorithm
algorithm = decomp.lanczos_tridiag_full_reortho(order)

def estimate(vec, *parameters):
def matvec_p(v):
return matvec(v, *parameters)

length = linalg.vector_norm(vec)
vec /= length
basis, tridiag = decomp.decompose_fori_loop(vec, matvec_p, algorithm=algorithm)
(diag, off_diag) = tridiag

# todo: once jax supports eigh_tridiagonal(eigvals_only=False),
# use it here. Until then: an eigen-decomposition of size (order + 1)
# does not hurt too much...
diag = linalg.diagonal_matrix(diag)
offdiag1 = linalg.diagonal_matrix(off_diag, -1)
offdiag2 = linalg.diagonal_matrix(off_diag, 1)
dense_matrix = diag + offdiag1 + offdiag2
eigvals, eigvecs = linalg.eigh(dense_matrix)

fx_eigvals = func.vmap(matfun)(eigvals)
return length * (basis.T @ (eigvecs @ (fx_eigvals * eigvecs[0, :])))

return estimate


def matrix_poly_vector_product(matrix_poly_alg, /):
"""Implement a matrix-function-vector product via a polynomial expansion.

Parameters
----------
matrix_poly_alg
Which algorithm to use.
For example, the output of
[matrix_poly_chebyshev][matfree.matfun.matrix_poly_chebyshev].
"""
lower, upper, init_func, step_func, extract_func = matrix_poly_alg

def matvec(vec, *parameters):
final_state = control_flow.fori_loop(
lower=lower,
upper=upper,
body_fun=lambda _i, v: step_func(v, *parameters),
init_val=init_func(vec, *parameters),
)
return extract_func(final_state)

return matvec


def _chebyshev_nodes(n, /):
k = np.arange(n, step=1.0) + 1
return np.cos((2 * k - 1) / (2 * n) * np.pi())


def matrix_poly_chebyshev(matfun, order, matvec, /):
"""Construct an implementation of matrix-Chebyshev-polynomial interpolation.

This function assumes that the spectrum of the matrix-vector product
is contained in the interval (-1, 1), and that the matrix-function
is analytic on this interval.
If this is not the case,
transform the matrix-vector product and the matrix-function accordingly.
"""
# Construct nodes
nodes = _chebyshev_nodes(order)
fx_nodes = matfun(nodes)

class _ChebyshevState(containers.NamedTuple):
interpolation: Array
poly_coefficients: tuple[Array, Array]
poly_values: tuple[Array, Array]

def init_func(vec, *parameters):
# Initialize the scalar recursion
# (needed to compute the interpolation weights)
t2_n, t1_n = nodes, np.ones_like(nodes)
c1 = np.mean(fx_nodes * t1_n)
c2 = 2 * np.mean(fx_nodes * t2_n)

# Initialize the vector-valued recursion
# (this is where the matvec happens)
t2_x, t1_x = matvec(vec, *parameters), vec
value = c1 * t1_x + c2 * t2_x
return _ChebyshevState(value, (t2_n, t1_n), (t2_x, t1_x))

def recursion_func(val: _ChebyshevState, *parameters) -> _ChebyshevState:
value, (t2_n, t1_n), (t2_x, t1_x) = val

# Apply the next scalar recursion and
# compute the next coefficient
t2_n, t1_n = 2 * nodes * t2_n - t1_n, t2_n
c2 = 2 * np.mean(fx_nodes * t2_n)

# Apply the next matrix-vector product recursion and
# compute the next interpolation-value
t2_x, t1_x = 2 * matvec(t2_x, *parameters) - t1_x, t2_x
value += c2 * t2_x
return _ChebyshevState(value, (t2_n, t1_n), (t2_x, t1_x))

def extract_func(val: _ChebyshevState):
return val.interpolation

return 0, order - 1, init_func, recursion_func, extract_func
2 changes: 0 additions & 2 deletions matfree/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

def symmetric_matrix_from_eigenvalues(eigvals, /):
"""Generate a symmetric matrix with prescribed eigenvalues."""
assert np.array_min(eigvals) > 0
(n,) = eigvals.shape

# Need _some_ matrix to start with
Expand All @@ -25,7 +24,6 @@ def symmetric_matrix_from_eigenvalues(eigvals, /):

def asymmetric_matrix_from_singular_values(vals, /, nrows, ncols):
"""Generate an asymmetric matrix with specific singular values."""
assert np.array_min(vals) > 0
A = np.reshape(np.arange(1.0, nrows * ncols + 1.0), (nrows, ncols))
A /= nrows * ncols
U, S, Vt = linalg.svd(A, full_matrices=False)
Expand Down
33 changes: 33 additions & 0 deletions tests/test_matfun/test_lanczos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Test matrix-function-vector products via Lanczos' algorithm."""
from matfree import matfun, test_util
from matfree.backend import linalg, np, prng


def test_lanczos_via_matfun_vector_product(n=11):
"""Test matrix-function-vector products via Lanczos' algorithm."""
# Create a test-problem: matvec, matrix function,
# vector, and parameters (a matrix).

def matvec(x, p):
return p @ x

# todo: write a test for matfun=np.inv,
# because this application seems to be brittle
def fun(x):
return np.sin(x)

v = prng.normal(prng.prng_key(2), shape=(n,))

eigvals = np.linspace(0.01, 0.99, num=n)
matrix = test_util.symmetric_matrix_from_eigenvalues(eigvals)

# Compute the solution
eigvals, eigvecs = linalg.eigh(matrix)
log_matrix = eigvecs @ linalg.diagonal(fun(eigvals)) @ eigvecs.T
expected = log_matrix @ v

# Compute the matrix-function vector product
order = 6
matfun_vec = matfun.lanczos_spd(fun, order, matvec)
received = matfun_vec(v, matrix)
assert np.allclose(expected, received, atol=1e-6)
36 changes: 36 additions & 0 deletions tests/test_matfun/test_matrix_poly_chebyshev.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Test matrix-polynomial-vector algorithms via Chebyshev's recursion."""
from matfree import matfun, test_util
from matfree.backend import linalg, np, prng


def test_matrix_poly_chebyshev(n=12):
"""Test matrix-polynomial-vector algorithms via Chebyshev's recursion."""
# Create a test-problem: matvec, matrix function,
# vector, and parameters (a matrix).

def matvec(x, p):
return p @ x

# todo: write a test for matfun=np.inv,
# because this application seems to be brittle
def fun(x):
return np.sin(x)

v = prng.normal(prng.prng_key(2), shape=(n,))

eigvals = np.linspace(-1 + 0.01, 1 - 0.01, num=n)
matrix = test_util.symmetric_matrix_from_eigenvalues(eigvals)

# Compute the solution
eigvals, eigvecs = linalg.eigh(matrix)
log_matrix = eigvecs @ linalg.diagonal(fun(eigvals)) @ eigvecs.T
expected = log_matrix @ v

# Create an implementation of the Chebyshev-algorithm
order = 6
algorithm = matfun.matrix_poly_chebyshev(fun, order, matvec)

# Compute the matrix-function vector product
matfun_vec = matfun.matrix_poly_vector_product(algorithm)
received = matfun_vec(v, matrix)
assert np.allclose(expected, received, rtol=1e-4)
Loading