diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index dddcd4b..d66d425 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -3,7 +3,7 @@ default_language_version:
   python: python3
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.6.0
+    rev: v5.0.0
     hooks:
       - id: check-yaml
       - id: end-of-file-fixer
@@ -11,7 +11,7 @@ repos:
       - id: check-merge-conflict
   - repo: https://github.com/astral-sh/ruff-pre-commit
     # Ruff version.
-    rev: v0.6.4
+    rev: v0.8.6
     hooks:
       # Run the linter.
       - id: ruff
@@ -19,7 +19,7 @@ repos:
       # Run the formatter.
       - id: ruff-format
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.11.2
+    rev: v1.14.1
    hooks:
       - id: mypy
         args:
diff --git a/matfree/decomp.py b/matfree/decomp.py
index e2c1796..8981ec9 100644
--- a/matfree/decomp.py
+++ b/matfree/decomp.py
@@ -583,6 +583,8 @@ def bidiag(depth: int, /, materialize: bool = True):
     Decompose a matrix into a product of orthogonal-**bidiagonal**-orthogonal matrices.
     Use this algorithm for approximate **singular value** decompositions.
 
+    Internally, Matfree uses JAX to turn matrix-vector into vector-matrix products.
+
     ??? note "A note about differentiability"
         Unlike [tridiag_sym][matfree.decomp.tridiag_sym] or
         [hessenberg][matfree.decomp.hessenberg], this function's reverse-mode
@@ -592,7 +594,7 @@
 
     """
 
-    def estimate(Av: Callable, vA: Callable, v0, *parameters):
+    def estimate(Av: Callable, v0, *parameters):
         # Infer the size of A from v0
         (ncols,) = np.shape(v0)
         w0_like = func.eval_shape(Av, v0)
@@ -610,7 +612,7 @@ def estimate(Av: Callable, vA: Callable, v0, *parameters):
         init_val = init(v0_norm, nrows=nrows, ncols=ncols)
 
         def body_fun(_, s):
-            return step(Av, vA, s, *parameters)
+            return step(Av, s, *parameters)
 
         result = control_flow.fori_loop(
             0, depth + 1, body_fun=body_fun, init_val=init_val
@@ -640,18 +642,21 @@ def init(init_vec: Array, *, nrows, ncols) -> State:
         v0, _ = _normalise(init_vec)
         return State(0, Us, Vs, alphas, betas, np.zeros(()), v0)
 
-    def step(Av, vA, state: State, *parameters) -> State:
+    def step(Av, state: State, *parameters) -> State:
         i, Us, Vs, alphas, betas, beta, vk = state
         Vs = Vs.at[i].set(vk)
         betas = betas.at[i].set(beta)
 
-        uk = Av(vk, *parameters) - beta * Us[i - 1]
+        # Use jax.vjp to evaluate the vector-matrix product
+        Av_eval, vA = func.vjp(lambda v: Av(v, *parameters), vk)
+        uk = Av_eval - beta * Us[i - 1]
         uk, alpha = _normalise(uk)
         uk, *_ = _gram_schmidt_classical(uk, Us)  # full reorthogonalisation
         Us = Us.at[i].set(uk)
         alphas = alphas.at[i].set(alpha)
 
-        vk = vA(uk, *parameters) - alpha * vk
+        (vA_eval,) = vA(uk)
+        vk = vA_eval - alpha * vk
         vk, beta = _normalise(vk)
         vk, *_ = _gram_schmidt_classical(vk, Vs)  # full reorthogonalisation
 
diff --git a/matfree/eig.py b/matfree/eig.py
index be3b313..7c6da75 100644
--- a/matfree/eig.py
+++ b/matfree/eig.py
@@ -6,7 +6,7 @@
 
 
 # todo: why does this function not return a callable?
-def svd_partial(v0: Array, depth: int, Av: Callable, vA: Callable):
+def svd_partial(v0: Array, depth: int, Av: Callable):
     """Partial singular value decomposition.
 
     Combines bidiagonalisation with full reorthogonalisation
@@ -27,7 +27,7 @@ def svd_partial(v0: Array, depth: int, Av: Callable, vA: Callable):
     """
     # Factorise the matrix
     algorithm = decomp.bidiag(depth, materialize=True)
-    (u, v), B, *_ = algorithm(Av, vA, v0)
+    (u, v), B, *_ = algorithm(Av, v0)
 
     # Compute SVD of factorisation
     U, S, Vt = linalg.svd(B, full_matrices=False)
diff --git a/matfree/funm.py b/matfree/funm.py
index 1ab67fe..5c457b2 100644
--- a/matfree/funm.py
+++ b/matfree/funm.py
@@ -259,8 +259,7 @@ def integrand_funm_product(dense_funm, algorithm, /):
     Here, "product" refers to $X = A^\top A$.
     """
 
-    def quadform(matvecs, v0, *parameters):
-        matvec, vecmat = matvecs
+    def quadform(matvec, v0, *parameters):
         v0_flat, v_unflatten = tree_util.ravel_pytree(v0)
         length = linalg.vector_norm(v0_flat)
         v0_flat /= length
@@ -268,21 +267,13 @@
         def matvec_flat(v_flat, *p):
             v = v_unflatten(v_flat)
             Av = matvec(v, *p)
-            flat, unflatten = tree_util.ravel_pytree(Av)
-            return flat, tree_util.partial_pytree(unflatten)
-
-        w0_flat, w_unflatten = func.eval_shape(matvec_flat, v0_flat)
-
-        def vecmat_flat(w_flat):
-            w = w_unflatten(w_flat)
-            wA = vecmat(w, *parameters)
-            return tree_util.ravel_pytree(wA)[0]
+            flat, _unflatten = tree_util.ravel_pytree(Av)
+            return flat
 
         # Decompose into orthogonal-bidiag-orthogonal
-        matvec_flat_p = lambda v: matvec_flat(v)[0]  # noqa: E731
-        output = algorithm(matvec_flat_p, vecmat_flat, v0_flat, *parameters)
-        _, B, *_ = output
+        _, B, *_ = algorithm(matvec_flat, v0_flat, *parameters)
 
+        # Evaluate matfun
         fA = dense_funm(B)
         e1 = np.eye(len(fA))[0, :]
         return length**2 * linalg.inner(e1, fA @ e1)
diff --git a/tests/test_decomp/test_bidiag.py b/tests/test_decomp/test_bidiag.py
index f01ae5d..a523fa9 100644
--- a/tests/test_decomp/test_bidiag.py
+++ b/tests/test_decomp/test_bidiag.py
@@ -32,7 +32,7 @@ def vA(v):
         return v @ A
 
     algorithm = decomp.bidiag(order, materialize=True)
-    (U, V), B, res, ln = algorithm(Av, vA, v0)
+    (U, V), B, res, ln = algorithm(Av, v0)
 
     test_util.assert_columns_orthonormal(U)
     test_util.assert_columns_orthonormal(V)
@@ -51,9 +51,9 @@ def test_error_too_high_depth(nrows, ncols, num_significant_singular_vals):
     A = make_A(nrows, ncols, num_significant_singular_vals)
     max_depth = min(nrows, ncols) - 1
 
-    with testing.raises(ValueError, match=""):
+    with testing.raises(ValueError, match="exceeds"):
         alg = decomp.bidiag(max_depth + 1, materialize=False)
-        _ = alg(lambda v: A @ v, lambda v: A.T @ v, A[0])
+        _ = alg(lambda v: A @ v, A[0])
 
 
 @testing.parametrize("nrows", [5])
@@ -63,9 +63,9 @@ def test_error_too_low_depth(nrows, ncols, num_significant_singular_vals):
     """Assert that a ValueError is raised when the depth is negative."""
     A = make_A(nrows, ncols, num_significant_singular_vals)
     min_depth = 0
-    with testing.raises(ValueError, match=""):
+    with testing.raises(ValueError, match="exceeds"):
         alg = decomp.bidiag(min_depth - 1, materialize=False)
-        _ = alg(lambda v: A @ v, lambda v: A.T @ v, A[0])
+        _ = alg(lambda v: A @ v, A[0])
 
 
 @testing.parametrize("nrows", [15])
@@ -84,7 +84,7 @@ def vA(v):
         return v @ A
 
     algorithm = decomp.bidiag(0, materialize=False)
-    (U, V), (d_m, e_m), res, ln = algorithm(Av, vA, v0)
+    (U, V), (d_m, e_m), res, ln = algorithm(Av, v0)
 
     assert np.shape(U) == (nrows, 1)
     assert np.shape(V) == (ncols, 1)
diff --git a/tests/test_eig/test_svd_partial.py b/tests/test_eig/test_svd_partial.py
index c852e02..32e2c16 100644
--- a/tests/test_eig/test_svd_partial.py
+++ b/tests/test_eig/test_svd_partial.py
@@ -29,12 +29,9 @@ def test_equal_to_linalg_svd(A):
     def Av(v):
         return A @ v
 
-    def vA(v):
-        return v @ A
-
     v0 = np.ones((ncols,))
     v0 /= linalg.vector_norm(v0)
-    U, S, Vt = eig.svd_partial(v0, depth, Av, vA)
+    U, S, Vt = eig.svd_partial(v0, depth, Av)
     U_, S_, Vt_ = linalg.svd(A, full_matrices=False)
 
     tols_decomp = {"atol": 1e-5, "rtol": 1e-5}
diff --git a/tests/test_funm/test_integrand_funm_product_logdet.py b/tests/test_funm/test_integrand_funm_product_logdet.py
index c3ceda0..bd124ac 100644
--- a/tests/test_funm/test_integrand_funm_product_logdet.py
+++ b/tests/test_funm/test_integrand_funm_product_logdet.py
@@ -25,16 +25,13 @@ def test_logdet_product(nrows, ncols, num_significant_singular_vals, order):
     def matvec(x):
         return {"fx": A @ x["fx"]}
 
-    def vecmat(x):
-        return {"fx": x["fx"] @ A}
-
     x_like = {"fx": np.ones((ncols,), dtype=float)}
     fun = stochtrace.sampler_normal(x_like, num=400)
 
     bidiag = decomp.bidiag(order)
     problem = funm.integrand_funm_product_logdet(bidiag)
     estimate = stochtrace.estimator(problem, fun)
-    received = estimate((matvec, vecmat), key)
+    received = estimate(matvec, key)
 
     expected = linalg.slogdet(A.T @ A)[1]
     print_if_assert_fails = ("error", np.abs(received - expected), "target:", expected)
@@ -61,7 +58,7 @@ def test_logdet_product_exact_for_full_order_lanczos(n):
     x = prng.normal(prng.prng_key(seed=1), shape=(n,)) + 1
 
     # Compute v^\top @ log(A) @ v via Lanczos
-    received = integrand((lambda v: A @ v, lambda v: v @ A), x)
+    received = integrand(lambda v: A @ v, x)
 
     # Compute the "true" value of v^\top @ log(A) @ v via eigenvalues
     eigvals, eigvecs = linalg.eigh(A.T @ A)
diff --git a/tests/test_funm/test_integrand_funm_product_schatten_norm.py b/tests/test_funm/test_integrand_funm_product_schatten_norm.py
index 34a9579..139a1d9 100644
--- a/tests/test_funm/test_integrand_funm_product_schatten_norm.py
+++ b/tests/test_funm/test_integrand_funm_product_schatten_norm.py
@@ -32,7 +32,7 @@ def test_schatten_norm(nrows, ncols, num_significant_singular_vals, order, power
     estimate = stochtrace.estimator(integrand, sampler)
 
     key = prng.prng_key(1)
-    received = estimate((lambda v: A @ v, lambda v: A.T @ v), key)
+    received = estimate(lambda v: A @ v, key)
 
     print_if_assert_fails = ("error", np.abs(received - expected), "target:", expected)
     assert np.allclose(received, expected, atol=1e-2, rtol=1e-2), print_if_assert_fails
diff --git a/tutorials/1_log_determinants.py b/tutorials/1_log_determinants.py
index 92c52cd..1c5bca8 100644
--- a/tutorials/1_log_determinants.py
+++ b/tutorials/1_log_determinants.py
@@ -48,24 +48,22 @@ def matvec(x):
 A /= nrows**2
 
 
-def matvec_r(x):
+def matvec_half(x):
     """Compute a matrix-vector product."""
     return A @ x
 
 
-def vecmat_l(x):
-    """Compute a vector-matrix product."""
-    return x @ A
-
-
 order = 3
 bidiag = decomp.bidiag(order)
 problem = funm.integrand_funm_product_logdet(bidiag)
 sampler = stochtrace.sampler_normal(x_like, num=1_000)
 estimator = stochtrace.estimator(problem, sampler=sampler)
-logdet = estimator((matvec_r, vecmat_l), jax.random.PRNGKey(1))
+logdet = estimator(matvec_half, jax.random.PRNGKey(1))
 print(logdet)
 
+# Internally, Matfree uses JAX's vector-Jacobian products to
+# turn the matrix-vector product into a vector-matrix product.
+#
 # For comparison:
 print(jnp.linalg.slogdet(A.T @ A))
 
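
A minimal usage sketch of the API after this patch. The matrix A, its shape (10, 6), the depth 5, and all variable names below are illustrative, not part of the diff; only the call signatures of decomp.bidiag, eig.svd_partial, and the funm integrand are taken from the changes above.

# Sketch: callers now supply only the matrix-vector product; the
# vector-matrix product is derived internally via jax.vjp.
import jax
import jax.numpy as jnp

from matfree import decomp, eig, funm

# Illustrative dense matrix (any matvec would do; A is never required).
A = jax.random.normal(jax.random.PRNGKey(0), (10, 6))


def Av(v):
    """Compute a matrix-vector product: the only operator callers supply now."""
    return A @ v


v0 = jnp.ones((6,))
v0 /= jnp.linalg.norm(v0)

# Bidiagonalisation without the former vA argument:
algorithm = decomp.bidiag(5, materialize=True)
(U, V), B, *_ = algorithm(Av, v0)

# Partial SVD built on the bidiagonalisation:
U, S, Vt = eig.svd_partial(v0, 5, Av)

# The funm integrands follow the same single-matvec convention:
integrand = funm.integrand_funm_product_logdet(decomp.bidiag(5))
value = integrand(Av, v0)

# What happens internally, simplified: jax.vjp turns the matrix-vector
# product into the vector-matrix product u @ A without forming A.T.
u = jnp.ones((10,))
_, vjp_fun = jax.vjp(Av, v0)
(uA,) = vjp_fun(u)
assert jnp.allclose(uA, u @ A)

One practical consequence of this design: callers can no longer pass an inconsistent (Av, vA) pair, since the transpose operator is derived from Av automatically.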