From 989d87c828c0bfd022f666579b2d239589dcb83d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Tue, 7 Jan 2025 15:46:26 +0100 Subject: [PATCH 1/2] Adopt jax.tree.* and rename backend.tree_util to backend.tree --- matfree/backend/{tree_util.py => tree.py} | 10 ++++---- matfree/decomp.py | 8 +++---- matfree/funm.py | 10 ++++---- matfree/stochtrace.py | 24 +++++++++---------- matfree/test_util.py | 4 ++-- tests/test_decomp/test_consistency.py | 4 ++-- tests/test_decomp/test_tridiag_sym_adjoint.py | 6 ++--- tests/test_stochtrace/test_diagonal.py | 6 ++--- .../test_frobeniusnorm_squared.py | 4 ++-- tests/test_stochtrace/test_trace.py | 4 ++-- .../test_trace_and_diagonal.py | 4 ++-- tests/test_stochtrace/test_wrap_moments.py | 4 ++-- tutorials/2_pytree_logdeterminants.py | 8 +++---- 13 files changed, 48 insertions(+), 48 deletions(-) rename matfree/backend/{tree_util.py => tree.py} (56%) diff --git a/matfree/backend/tree_util.py b/matfree/backend/tree.py similarity index 56% rename from matfree/backend/tree_util.py rename to matfree/backend/tree.py index 605c679..14fa0ae 100644 --- a/matfree/backend/tree_util.py +++ b/matfree/backend/tree.py @@ -5,11 +5,11 @@ def tree_all(tree, /): - return jax.tree_util.tree_all(tree) + return jax.tree.all(tree) def tree_map(func, tree, *rest): - return jax.tree_util.tree_map(func, tree, *rest) + return jax.tree.map(func, tree, *rest) def ravel_pytree(tree, /): @@ -17,12 +17,12 @@ def ravel_pytree(tree, /): def tree_leaves(tree, /): - return jax.tree_util.tree_leaves(tree) + return jax.tree.leaves(tree) def tree_structure(tree, /): - return jax.tree_util.tree_structure(tree) + return jax.tree.structure(tree) def partial_pytree(func, /): - return jax.tree_util.Partial(func) + return jax.tree.Partial(func) diff --git a/matfree/decomp.py b/matfree/decomp.py index ece5440..3e78461 100644 --- a/matfree/decomp.py +++ b/matfree/decomp.py @@ -8,7 +8,7 @@ [matfree.funm][matfree.funm]. """ -from matfree.backend import containers, control_flow, func, linalg, np, tree_util +from matfree.backend import containers, control_flow, func, linalg, np, tree from matfree.backend.typing import Array, Callable, Union @@ -304,7 +304,7 @@ def adjoint_step(xi_and_lambda, inputs): ) # Compute the gradients - grad_matvec = tree_util.tree_map(lambda s: np.sum(s, axis=0), grad_summands) + grad_matvec = tree.tree_map(lambda s: np.sum(s, axis=0), grad_summands) grad_initvec = ((lambda_1.T @ xs[0]) * xs[0] - lambda_1) / initvec_norm # Return values @@ -477,7 +477,7 @@ def lower(m): lambda_k = dr + Q @ eta Lambda = np.zeros_like(Q) Gamma = np.zeros_like(dQ.T @ Q) - dp = tree_util.tree_map(np.zeros_like, params) + dp = tree.tree_map(np.zeros_like, params) # Prepare more auxiliary matrices Pi_xi = dQ.T + linalg.outer(eta, r) @@ -562,7 +562,7 @@ def _hessenberg_adjoint_step( # Transposed matvec and parameter-gradient in a single matvec _, vjp = func.vjp(lambda u, v: matvec(u, *v), q, params) vecmat_lambda, dp_increment = vjp(lambda_k) - dp = tree_util.tree_map(lambda g, h: g + h, dp, dp_increment) + dp = tree.tree_map(lambda g, h: g + h, dp, dp_increment) # Solve for (Gamma + Gamma.T) e_K tmp = lower_mask * (Pi_gamma - vecmat_lambda @ Q) diff --git a/matfree/funm.py b/matfree/funm.py index d12a8bb..327827c 100644 --- a/matfree/funm.py +++ b/matfree/funm.py @@ -37,7 +37,7 @@ """ -from matfree.backend import containers, control_flow, func, linalg, np, tree_util +from matfree.backend import containers, control_flow, func, linalg, np, tree from matfree.backend.typing import Array, Callable @@ -212,14 +212,14 @@ def integrand_funm_sym(dense_funm, tridiag_sym, /): """ def quadform(matvec, v0, *parameters): - v0_flat, v_unflatten = tree_util.ravel_pytree(v0) + v0_flat, v_unflatten = tree.ravel_pytree(v0) length = linalg.vector_norm(v0_flat) v0_flat /= length def matvec_flat(v_flat, *p): v = v_unflatten(v_flat) Av = matvec(v, *p) - flat, unflatten = tree_util.ravel_pytree(Av) + flat, unflatten = tree.ravel_pytree(Av) return flat _, dense, *_ = tridiag_sym(matvec_flat, v0_flat, *parameters) @@ -260,14 +260,14 @@ def integrand_funm_product(dense_funm, algorithm, /): """ def quadform(matvec, v0, *parameters): - v0_flat, v_unflatten = tree_util.ravel_pytree(v0) + v0_flat, v_unflatten = tree.ravel_pytree(v0) length = linalg.vector_norm(v0_flat) v0_flat /= length def matvec_flat(v_flat, *p): v = v_unflatten(v_flat) Av = matvec(v, *p) - flat, _unflatten = tree_util.ravel_pytree(Av) + flat, _unflatten = tree.ravel_pytree(Av) return flat # Decompose into orthogonal-bidiag-orthogonal diff --git a/matfree/stochtrace.py b/matfree/stochtrace.py index 1704098..e07f353 100644 --- a/matfree/stochtrace.py +++ b/matfree/stochtrace.py @@ -1,6 +1,6 @@ """Stochastic estimation of traces, diagonals, and more.""" -from matfree.backend import func, linalg, np, prng, tree_util +from matfree.backend import func, linalg, np, prng, tree from matfree.backend.typing import Callable @@ -30,7 +30,7 @@ def estimator(integrand: Callable, /, sampler: Callable) -> Callable: def estimate(matvecs, key, *parameters): samples = sampler(key) Qs = func.vmap(lambda vec: integrand(matvecs, vec, *parameters))(samples) - return tree_util.tree_map(lambda s: np.mean(s, axis=0), Qs) + return tree.tree_map(lambda s: np.mean(s, axis=0), Qs) return estimate @@ -49,8 +49,8 @@ def integrand_diagonal(): def integrand(matvec, v, *parameters): Qv = matvec(v, *parameters) - v_flat, unflatten = tree_util.ravel_pytree(v) - Qv_flat, _unflatten = tree_util.ravel_pytree(Qv) + v_flat, unflatten = tree.ravel_pytree(v) + Qv_flat, _unflatten = tree.ravel_pytree(Qv) return unflatten(v_flat * Qv_flat) return integrand @@ -61,8 +61,8 @@ def integrand_trace(): def integrand(matvec, v, *parameters): Qv = matvec(v, *parameters) - v_flat, unflatten = tree_util.ravel_pytree(v) - Qv_flat, _unflatten = tree_util.ravel_pytree(Qv) + v_flat, unflatten = tree.ravel_pytree(v) + Qv_flat, _unflatten = tree.ravel_pytree(Qv) return linalg.inner(v_flat, Qv_flat) return integrand @@ -73,8 +73,8 @@ def integrand_trace_and_diagonal(): def integrand(matvec, v, *parameters): Qv = matvec(v, *parameters) - v_flat, unflatten = tree_util.ravel_pytree(v) - Qv_flat, _unflatten = tree_util.ravel_pytree(Qv) + v_flat, unflatten = tree.ravel_pytree(v) + Qv_flat, _unflatten = tree.ravel_pytree(Qv) trace_form = linalg.inner(v_flat, Qv_flat) diagonal_form = unflatten(v_flat * Qv_flat) return {"trace": trace_form, "diagonal": diagonal_form} @@ -87,7 +87,7 @@ def integrand_frobeniusnorm_squared(): def integrand(matvec, vec, *parameters): x = matvec(vec, *parameters) - v_flat, unflatten = tree_util.ravel_pytree(x) + v_flat, unflatten = tree.ravel_pytree(x) return linalg.inner(v_flat, v_flat) return integrand @@ -119,10 +119,10 @@ def integrand_wrap_moments(integrand, /, moments): def integrand_wrapped(vec, *parameters): Qs = integrand(vec, *parameters) - return tree_util.tree_map(moment_fun, Qs) + return tree.tree_map(moment_fun, Qs) def moment_fun(x, /): - return tree_util.tree_map(lambda m: x**m, moments) + return tree.tree_map(lambda m: x**m, moments) return integrand_wrapped @@ -138,7 +138,7 @@ def sampler_rademacher(*args_like, num): def _sampler_from_jax_random(sampler, *args_like, num): - x_flat, unflatten = tree_util.ravel_pytree(*args_like) + x_flat, unflatten = tree.ravel_pytree(*args_like) def sample(key): samples = sampler(key, shape=(num, *x_flat.shape), dtype=x_flat.dtype) diff --git a/matfree/test_util.py b/matfree/test_util.py index dfe9557..e48f9d8 100644 --- a/matfree/test_util.py +++ b/matfree/test_util.py @@ -1,6 +1,6 @@ """Test utilities.""" -from matfree.backend import linalg, np, prng, tree_util +from matfree.backend import linalg, np, prng def symmetric_matrix_from_eigenvalues(eigvals, /): @@ -47,7 +47,7 @@ def to_dense_tridiag_sym(d, e, /): def tree_random_like(key, tree, *, generate_func=prng.normal): """Fill a tree with random values.""" - flat, unflatten = tree_util.ravel_pytree(tree) + flat, unflatten = tree.ravel_pytree(tree) flat_like = generate_func(key, shape=flat.shape, dtype=flat.dtype) return unflatten(flat_like) diff --git a/tests/test_decomp/test_consistency.py b/tests/test_decomp/test_consistency.py index c054c70..352622a 100644 --- a/tests/test_decomp/test_consistency.py +++ b/tests/test_decomp/test_consistency.py @@ -1,7 +1,7 @@ """Ensure that bidiag, tridiag, etc. have consistent signatures.""" from matfree import decomp -from matfree.backend import np, prng, testing, tree_util +from matfree.backend import np, prng, testing, tree def case_method_bidiag(): @@ -38,7 +38,7 @@ def test_output_shape_as_expected(nrows, num_matvecs, method): Us, B, res, ln = algorithm(lambda v: A @ v, v0) # Normalise the Us to always have a list - Us = tree_util.tree_leaves(Us) + Us = tree.tree_leaves(Us) for U in Us: assert np.shape(U) == (nrows, num_matvecs) diff --git a/tests/test_decomp/test_tridiag_sym_adjoint.py b/tests/test_decomp/test_tridiag_sym_adjoint.py index 456fa2e..a73bbbf 100644 --- a/tests/test_decomp/test_tridiag_sym_adjoint.py +++ b/tests/test_decomp/test_tridiag_sym_adjoint.py @@ -1,7 +1,7 @@ """Test the adjoint of tri-diagonalisation.""" from matfree import decomp, test_util -from matfree.backend import func, linalg, np, prng, testing, tree_util +from matfree.backend import func, linalg, np, prng, testing, tree @testing.parametrize("reortho", ["full", "none"]) @@ -19,14 +19,14 @@ def matvec(s, p): vector = prng.normal(prng.prng_key(1), shape=(n,)) # Flatten the inputs - flat, unflatten = tree_util.ravel_pytree((vector, params)) + flat, unflatten = tree.ravel_pytree((vector, params)) # Construct a vector-to-vector decomposition function def decompose(f, *, custom_vjp): kwargs = {"reortho": reortho, "custom_vjp": custom_vjp, "materialize": False} algorithm = decomp.tridiag_sym(krylov_num_matvecs, **kwargs) output = algorithm(matvec, *unflatten(f)) - return tree_util.ravel_pytree(output)[0] + return tree.ravel_pytree(output)[0] # Construct the two implementations reference = func.jit(func.partial(decompose, custom_vjp=False)) diff --git a/tests/test_stochtrace/test_diagonal.py b/tests/test_stochtrace/test_diagonal.py index ecd8cd8..64befac 100644 --- a/tests/test_stochtrace/test_diagonal.py +++ b/tests/test_stochtrace/test_diagonal.py @@ -1,7 +1,7 @@ """Test the diagonal estimation.""" from matfree import stochtrace -from matfree.backend import func, linalg, np, prng, tree_util +from matfree.backend import func, linalg, np, prng, tree def test_diagonal(): @@ -20,7 +20,7 @@ def fun(x): _, jvp = func.linearize(fun, args_like) J = func.jacfwd(fun)(args_like)["params"] - expected = tree_util.tree_map(linalg.diagonal, J) + expected = tree.tree_map(linalg.diagonal, J) # Estimate the matrix function problem = stochtrace.integrand_diagonal() @@ -31,4 +31,4 @@ def fun(x): def compare(a, b): return np.allclose(a, b, rtol=1e-2) - assert tree_util.tree_all(tree_util.tree_map(compare, received, expected)) + assert tree.tree_all(tree.tree_map(compare, received, expected)) diff --git a/tests/test_stochtrace/test_frobeniusnorm_squared.py b/tests/test_stochtrace/test_frobeniusnorm_squared.py index ce48224..e87d975 100644 --- a/tests/test_stochtrace/test_frobeniusnorm_squared.py +++ b/tests/test_stochtrace/test_frobeniusnorm_squared.py @@ -1,7 +1,7 @@ """Test the estimation of squared Frobenius-norms.""" from matfree import stochtrace -from matfree.backend import func, linalg, np, prng, tree_util +from matfree.backend import func, linalg, np, prng, tree def test_frobeniusnorm_squared(): @@ -18,7 +18,7 @@ def fun(x): x0 = prng.uniform(key, shape=(4,)) # random lin. point args_like = {"params": x0} _, jvp = func.linearize(fun, args_like) - [J] = tree_util.tree_leaves(func.jacfwd(fun)(args_like)) + [J] = tree.tree_leaves(func.jacfwd(fun)(args_like)) expected = linalg.trace(J.T @ J) # Estimate the matrix function diff --git a/tests/test_stochtrace/test_trace.py b/tests/test_stochtrace/test_trace.py index 1545e30..c1aa67d 100644 --- a/tests/test_stochtrace/test_trace.py +++ b/tests/test_stochtrace/test_trace.py @@ -1,7 +1,7 @@ """Test trace estimation.""" from matfree import stochtrace -from matfree.backend import func, linalg, np, prng, tree_util +from matfree.backend import func, linalg, np, prng, tree def test_trace(): @@ -30,4 +30,4 @@ def fun(x): def compare(a, b): return np.allclose(a, b, rtol=1e-2) - assert tree_util.tree_all(tree_util.tree_map(compare, received, expected)) + assert tree.tree_all(tree.tree_map(compare, received, expected)) diff --git a/tests/test_stochtrace/test_trace_and_diagonal.py b/tests/test_stochtrace/test_trace_and_diagonal.py index ef61393..e1a5b8c 100644 --- a/tests/test_stochtrace/test_trace_and_diagonal.py +++ b/tests/test_stochtrace/test_trace_and_diagonal.py @@ -1,7 +1,7 @@ """Test joint trace and diagonal estimation.""" from matfree import stochtrace -from matfree.backend import func, np, prng, tree_util +from matfree.backend import func, np, prng, tree def test_trace_and_diagonal(): @@ -39,4 +39,4 @@ def fun(x): expected = {"trace": expected_trace, "diagonal": expected_diagonal} - assert tree_util.tree_all(tree_util.tree_map(np.allclose, received, expected)) + assert tree.tree_all(tree.tree_map(np.allclose, received, expected)) diff --git a/tests/test_stochtrace/test_wrap_moments.py b/tests/test_stochtrace/test_wrap_moments.py index c6d2340..e831690 100644 --- a/tests/test_stochtrace/test_wrap_moments.py +++ b/tests/test_stochtrace/test_wrap_moments.py @@ -1,7 +1,7 @@ """Test the estimation with multiple statistics.""" from matfree import stochtrace -from matfree.backend import func, linalg, np, prng, testing, tree_util +from matfree.backend import func, linalg, np, prng, testing, tree def test_yields_correct_tree_structure(): @@ -32,7 +32,7 @@ def fun(x): expected = { "params": {"moment_1st": irrelevant_value, "moment_2nd": irrelevant_value} } - assert tree_util.tree_structure(received) == tree_util.tree_structure(expected) + assert tree.tree_structure(received) == tree.tree_structure(expected) @testing.fixture(name="key") diff --git a/tutorials/2_pytree_logdeterminants.py b/tutorials/2_pytree_logdeterminants.py index 937c855..9e92366 100644 --- a/tutorials/2_pytree_logdeterminants.py +++ b/tutorials/2_pytree_logdeterminants.py @@ -30,9 +30,9 @@ def testfunc(x): f0, jvp = jax.linearize(testfunc, x0) _f0, vjp = jax.vjp(testfunc, x0) -print(jax.tree_util.tree_map(jnp.shape, f0)) -print(jax.tree_util.tree_map(jnp.shape, jvp(x0))) -print(jax.tree_util.tree_map(jnp.shape, vjp(f0))) +print(jax.tree.map(jnp.shape, f0)) +print(jax.tree.map(jnp.shape, jvp(x0))) +print(jax.tree.map(jnp.shape, vjp(f0))) # Use the same API as if the matrix-vector product were array-valued. @@ -46,7 +46,7 @@ def fun(fx, /): r"""Matrix-vector product with $J J^\top + \alpha I$.""" vjp_eval = vjp(fx) matvec_eval = jvp(*vjp_eval) - return jax.tree_util.tree_map(lambda x, y: x + alpha * y, matvec_eval, fx) + return jax.tree.map(lambda x, y: x + alpha * y, matvec_eval, fx) return fun From 3d17d62da6233e7bf9c4838428e7368e9e35b969 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicholas=20Kr=C3=A4mer?= Date: Tue, 7 Jan 2025 15:52:17 +0100 Subject: [PATCH 2/2] Fix some naming-related errors --- matfree/backend/tree.py | 20 ++++++++++---------- matfree/test_util.py | 6 +++--- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/matfree/backend/tree.py b/matfree/backend/tree.py index 14fa0ae..ddc786b 100644 --- a/matfree/backend/tree.py +++ b/matfree/backend/tree.py @@ -4,24 +4,24 @@ import jax.flatten_util -def tree_all(tree, /): - return jax.tree.all(tree) +def tree_all(pytree, /): + return jax.tree.all(pytree) -def tree_map(func, tree, *rest): - return jax.tree.map(func, tree, *rest) +def tree_map(func, pytree, *rest): + return jax.tree.map(func, pytree, *rest) -def ravel_pytree(tree, /): - return jax.flatten_util.ravel_pytree(tree) +def ravel_pytree(pytree, /): + return jax.flatten_util.ravel_pytree(pytree) -def tree_leaves(tree, /): - return jax.tree.leaves(tree) +def tree_leaves(pytree, /): + return jax.tree.leaves(pytree) -def tree_structure(tree, /): - return jax.tree.structure(tree) +def tree_structure(pytree, /): + return jax.tree.structure(pytree) def partial_pytree(func, /): diff --git a/matfree/test_util.py b/matfree/test_util.py index e48f9d8..eec6762 100644 --- a/matfree/test_util.py +++ b/matfree/test_util.py @@ -1,6 +1,6 @@ """Test utilities.""" -from matfree.backend import linalg, np, prng +from matfree.backend import linalg, np, prng, tree def symmetric_matrix_from_eigenvalues(eigvals, /): @@ -45,9 +45,9 @@ def to_dense_tridiag_sym(d, e, /): return diag + offdiag1 + offdiag2 -def tree_random_like(key, tree, *, generate_func=prng.normal): +def tree_random_like(key, pytree, *, generate_func=prng.normal): """Fill a tree with random values.""" - flat, unflatten = tree.ravel_pytree(tree) + flat, unflatten = tree.ravel_pytree(pytree) flat_like = generate_func(key, shape=flat.shape, dtype=flat.dtype) return unflatten(flat_like)