Skip to content

Commit

Permalink
Rename matfun.py to polynomial.py and merge algorithm and decompositi…
Browse files Browse the repository at this point in the history
…on into a single function (#179)
  • Loading branch information
pnkraemer authored Jan 15, 2024
1 parent 48bb1be commit dfca7e5
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 46 deletions.
69 changes: 29 additions & 40 deletions matfree/matfun.py → matfree/polynomial.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,15 @@
"""Approximate matrix-function-vector products with polynomial expansions.
This module is experimental.
"""
"""Approximate matrix-function-vector products with polynomial expansions."""

from matfree.backend import containers, control_flow, np
from matfree.backend.typing import Array


def matrix_poly_vector_product(matrix_poly_alg, /):
"""Implement a matrix-function-vector product via a polynomial expansion.
Parameters
----------
matrix_poly_alg
Which algorithm to use.
For example, the output of
[matrix_poly_chebyshev][matfree.matfun.matrix_poly_chebyshev].
"""
lower, upper, init_func, step_func, extract_func = matrix_poly_alg

def matvec(vec, *parameters):
final_state = control_flow.fori_loop(
lower=lower,
upper=upper,
body_fun=lambda _i, v: step_func(v, *parameters),
init_val=init_func(vec, *parameters),
)
return extract_func(final_state)

return matvec


def _chebyshev_nodes(n, /):
k = np.arange(n, step=1.0) + 1
return np.cos((2 * k - 1) / (2 * n) * np.pi())

def funm_vector_product_chebyshev(matfun, order, matvec, /):
"""Compute a matrix-function-vector product via Chebyshev's algorithm.
def matrix_poly_chebyshev(matfun, order, matvec, /):
"""Construct an implementation of matrix-Chebyshev-polynomial interpolation.
This function assumes that the spectrum of the matrix-vector product
is contained in the interval (-1, 1), and that the matrix-function
is analytic on this interval.
If this is not the case,
This function assumes that the **spectrum of the matrix-vector product
is contained in the interval (-1, 1)**, and that the **matrix-function
is analytic on this interval**. If this is not the case,
transform the matrix-vector product and the matrix-function accordingly.
"""
# Construct nodes
Expand Down Expand Up @@ -84,4 +51,26 @@ def recursion_func(val: _ChebyshevState, *parameters) -> _ChebyshevState:
def extract_func(val: _ChebyshevState):
return val.interpolation

return 0, order - 1, init_func, recursion_func, extract_func
alg = (0, order - 1), init_func, recursion_func, extract_func
return _funm_vector_product_polyexpand(alg)


def _chebyshev_nodes(n, /):
k = np.arange(n, step=1.0) + 1
return np.cos((2 * k - 1) / (2 * n) * np.pi())


def _funm_vector_product_polyexpand(matrix_poly_alg, /):
"""Implement a matrix-function-vector product via a polynomial expansion."""
(lower, upper), init_func, step_func, extract_func = matrix_poly_alg

def matvec(vec, *parameters):
final_state = control_flow.fori_loop(
lower=lower,
upper=upper,
body_fun=lambda _i, v: step_func(v, *parameters),
init_val=init_func(vec, *parameters),
)
return extract_func(final_state)

return matvec
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
"""Test matrix-polynomial-vector algorithms via Chebyshev's recursion."""
from matfree import matfun, test_util
from matfree import polynomial, test_util
from matfree.backend import linalg, np, prng


def test_matrix_poly_chebyshev(n=12):
def test_funm_vector_product_chebyshev(n=12):
"""Test matrix-polynomial-vector algorithms via Chebyshev's recursion."""
# Create a test-problem: matvec, matrix function,
# vector, and parameters (a matrix).

def matvec(x, p):
return p @ x

# todo: write a test for matfun=np.inv,
# because this application seems to be brittle
def fun(x):
return np.sin(x)

Expand All @@ -28,9 +26,8 @@ def fun(x):

# Create an implementation of the Chebyshev-algorithm
order = 6
algorithm = matfun.matrix_poly_chebyshev(fun, order, matvec)
matfun_vec = polynomial.funm_vector_product_chebyshev(fun, order, matvec)

# Compute the matrix-function vector product
matfun_vec = matfun.matrix_poly_vector_product(algorithm)
received = matfun_vec(v, matrix)
assert np.allclose(expected, received, rtol=1e-4)

0 comments on commit dfca7e5

Please sign in to comment.