Skip to content

Commit

Permalink
Reverse the eigenvalue-range modifications
Browse files Browse the repository at this point in the history
  • Loading branch information
pnkraemer committed Jan 12, 2024
1 parent a7c6979 commit 8220e04
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
7 changes: 6 additions & 1 deletion matfree/matfun.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,12 @@ def _chebyshev_nodes(n, /):


def matrix_poly_chebyshev(matfun, order, matvec, /):
"""Construct an implementation of matrix-Chebyshev-polynomial interpolation."""
"""Construct an implementation of matrix-Chebyshev-polynomial interpolation.
This function assumes that the spectrum of the matrix-vector product
is contained in the interval (-1, 1), and that the matrix-function
is analytic on this interval.
"""
# Construct nodes
nodes = _chebyshev_nodes(order)
fx_nodes = matfun(nodes)
Expand Down
10 changes: 4 additions & 6 deletions tests/test_matfun/test_matrix_poly_chebyshev.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""Test matrix-polynomial-vector algorithms via Chebyshev's recursion."""
from matfree import matfun, test_util
from matfree.backend import linalg, np, prng, testing
from matfree.backend import linalg, np, prng


@testing.parametrize("eigvals_range", [(-1, 1), (3, 4)])
def test_matrix_poly_chebyshev(eigvals_range, n=4):
def test_matrix_poly_chebyshev(n=12):
"""Test matrix-polynomial-vector algorithms via Chebyshev's recursion."""
# Create a test-problem: matvec, matrix function,
# vector, and parameters (a matrix).
Expand All @@ -19,8 +18,7 @@ def fun(x):

v = prng.normal(prng.prng_key(2), shape=(n,))

eigvals = np.linspace(0, 1, num=n)
eigvals = eigvals_range[0] + eigvals * (eigvals_range[1] - eigvals_range[0])
eigvals = np.linspace(-1 + 0.01, 1 - 0.01, num=n)
matrix = test_util.symmetric_matrix_from_eigenvalues(eigvals)

# Compute the solution
Expand All @@ -35,4 +33,4 @@ def fun(x):
# Compute the matrix-function vector product
matfun_vec = matfun.matrix_poly_vector_product(algorithm)
received = matfun_vec(v, matrix)
assert np.allclose(expected, received)
assert np.allclose(expected, received, rtol=1e-4)

0 comments on commit 8220e04

Please sign in to comment.