Skip to content

Commit

Permalink
Implement validation of a unit-2-norm inside Lanczos-style decomposit…
Browse files Browse the repository at this point in the history
…ions
  • Loading branch information
pnkraemer committed Jan 8, 2024
1 parent b3bdb65 commit 7fe41b5
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 3 deletions.
8 changes: 8 additions & 0 deletions matfree/backend/np.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ def any(x, /): # noqa: A001
return jnp.any(x)


def all(x, /): # noqa: A001
return jnp.all(x)


def allclose(x1, x2, /, *, rtol=1e-5, atol=1e-8):
return jnp.allclose(x1, x2, rtol=rtol, atol=atol)

Expand Down Expand Up @@ -158,6 +162,10 @@ def nan():
return jnp.nan


def finfo_eps(x, /):
return jnp.finfo(x).eps


# Others


Expand Down
28 changes: 25 additions & 3 deletions matfree/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class _Alg(containers.NamedTuple):
"""Range of the for-loop used to decompose a matrix."""


def lanczos_tridiag_full_reortho(depth, /) -> AlgorithmType:
def lanczos_tridiag_full_reortho(depth, /, validate_unit_2_norm=False) -> AlgorithmType:
"""Construct an implementation of **tridiagonalisation**.
Uses pre-allocation. Fully reorthogonalise vectors at every step.
Expand All @@ -50,6 +50,9 @@ def init(init_vec: Array) -> State:
if depth >= ncols or depth < 1:
raise ValueError

if validate_unit_2_norm:
init_vec = _validate_unit_2_norm(init_vec)

diag = np.zeros((depth + 1,))
offdiag = np.zeros((depth,))
basis = np.zeros((depth + 1, ncols))
Expand Down Expand Up @@ -84,13 +87,16 @@ def apply(state: State, Av: Callable) -> State:
return State(i + 1, basis, (diag, offdiag), vec)

def extract(state: State, /):
_, basis, (diag, offdiag), _ = state
# todo: return final output "_ignored"
_, basis, (diag, offdiag), _ignored = state
return basis, (diag, offdiag)

return _Alg(init=init, step=apply, extract=extract, lower_upper=(0, depth + 1))


def lanczos_bidiag_full_reortho(depth, /, matrix_shape) -> AlgorithmType:
def lanczos_bidiag_full_reortho(
depth, /, matrix_shape, validate_unit_2_norm=False
) -> AlgorithmType:
"""Construct an implementation of **bidiagonalisation**.
Uses pre-allocation. Fully reorthogonalise vectors at every step.
Expand Down Expand Up @@ -118,6 +124,9 @@ class State(containers.NamedTuple):
vk: Array

def init(init_vec: Array) -> State:
if validate_unit_2_norm:
init_vec = _validate_unit_2_norm(init_vec)

nrows, ncols = matrix_shape
alphas = np.zeros((depth + 1,))
betas = np.zeros((depth + 1,))
Expand Down Expand Up @@ -152,6 +161,19 @@ def extract(state: State, /):
return _Alg(init=init, step=apply, extract=extract, lower_upper=(0, depth + 1))


def _validate_unit_2_norm(v, /):
# Lanczos assumes a unit-2-norm vector as an input
# We cannot raise an error based on values of the init_vec,
# but we can make it obvious that the result is unusable.
is_not_normalized = np.abs(linalg.vector_norm(v) - 1.0) > 10 * np.finfo_eps(v.dtype)
return control_flow.cond(
is_not_normalized,
lambda s: np.nan() * np.ones_like(s),
lambda s: s,
v,
)


def _normalise(vec):
length = linalg.vector_norm(vec)
return vec / length, length
Expand Down
2 changes: 2 additions & 0 deletions matfree/slq.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def integrand_slq_spd(matfun, order, matvec, /):

def quadform(v0, *parameters):
v0_flat, v_unflatten = tree_util.ravel_pytree(v0)
v0_flat /= linalg.vector_norm(v0_flat)

def matvec_flat(v_flat):
v = v_unflatten(v_flat)
Expand Down Expand Up @@ -85,6 +86,7 @@ def integrand_slq_product(matfun, depth, matvec, vecmat, /):

def quadform(v0, *parameters):
v0_flat, v_unflatten = tree_util.ravel_pytree(v0)
v0_flat /= linalg.vector_norm(v0_flat)

def matvec_flat(v_flat):
v = v_unflatten(v_flat)
Expand Down
31 changes: 31 additions & 0 deletions tests/test_decomp/test_bidiagonal_full_reortho.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def Av(v):
def vA(v):
return v @ A

v0 /= linalg.vector_norm(v0)
Us, Bs, Vs, (b, v) = decomp.decompose_fori_loop(v0, Av, vA, algorithm=alg)
(d_m, e_m) = Bs

Expand Down Expand Up @@ -111,3 +112,33 @@ def vA(v):
assert np.shape(e_m) == (0,)
assert np.shape(b) == ()
assert np.shape(v) == (ncols,)


@testing.parametrize("nrows", [15])
@testing.parametrize("ncols", [3])
@testing.parametrize("num_significant_singular_vals", [3])
@testing.parametrize("order", [2])
def test_validate_unit_norm(A, order):
"""Test that the outputs are NaN if the input is not normalized."""
nrows, ncols = np.shape(A)
algorithm = decomp.lanczos_bidiag_full_reortho(
order, matrix_shape=np.shape(A), validate_unit_2_norm=True
)
key = prng.prng_key(1)

# Not normalized!
v0 = prng.normal(key, shape=(ncols,)) + 1.0

def Av(v):
return A @ v

def vA(v):
return v @ A

Us, (d_m, e_m), Vs, (b, v) = decomp.decompose_fori_loop(
v0, Av, vA, algorithm=algorithm
)

# Since v0 is not normalized, all inputs are NaN
for x in (Us, d_m, e_m, Vs, b, v):
assert np.all(np.isnan(x))
1 change: 1 addition & 0 deletions tests/test_decomp/test_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def vA(v):
return v @ A

v0 = np.ones((ncols,))
v0 /= linalg.vector_norm(v0)
U, S, Vt = decomp.svd_approx(v0, depth, Av, vA, matrix_shape=np.shape(A))
U_, S_, Vt_ = linalg.svd(A, full_matrices=False)

Expand Down
21 changes: 21 additions & 0 deletions tests/test_decomp/test_tridiagonal_full_reortho.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def test_max_order(A):
order = n - 1
key = prng.prng_key(1)
v0 = prng.normal(key, shape=(n,))
v0 /= linalg.vector_norm(v0)
alg = decomp.lanczos_tridiag_full_reortho(order)
Q, (d_m, e_m) = decomp.decompose_fori_loop(v0, lambda v: A @ v, algorithm=alg)

Expand Down Expand Up @@ -63,6 +64,7 @@ def test_identity(A, order):
n, _ = np.shape(A)
key = prng.prng_key(1)
v0 = prng.normal(key, shape=(n,))
v0 /= linalg.vector_norm(v0)
alg = decomp.lanczos_tridiag_full_reortho(order)
Q, tridiag = decomp.decompose_fori_loop(v0, lambda v: A @ v, algorithm=alg)
(d_m, e_m) = tridiag
Expand Down Expand Up @@ -92,3 +94,22 @@ def _sym_tridiagonal_dense(d, e):
offdiag1 = linalg.diagonal_matrix(e, 1)
offdiag2 = linalg.diagonal_matrix(e, -1)
return diag + offdiag1 + offdiag2


@testing.parametrize("n", [50])
@testing.parametrize("num_significant_eigvals", [4])
@testing.parametrize("order", [6]) # ~1.5 * num_significant_eigvals
def test_validate_unit_norm(A, order):
"""Test that the outputs are NaN if the input is not normalized."""
n, _ = np.shape(A)
key = prng.prng_key(1)

# Not normalized!
v0 = prng.normal(key, shape=(n,)) + 1.0

alg = decomp.lanczos_tridiag_full_reortho(order, validate_unit_2_norm=True)
Q, (d_m, e_m) = decomp.decompose_fori_loop(v0, lambda v: A @ v, algorithm=alg)

# Since v0 is not normalized, all inputs are NaN
for x in (Q, d_m, e_m):
assert np.all(np.isnan(x))

0 comments on commit 7fe41b5

Please sign in to comment.