Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adopt jax.tree.* and rename backend.tree_util to backend.tree #231

Merged
merged 2 commits into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions matfree/backend/tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""PyTree utilities."""

import jax
import jax.flatten_util


def tree_all(pytree, /):
return jax.tree.all(pytree)


def tree_map(func, pytree, *rest):
return jax.tree.map(func, pytree, *rest)


def ravel_pytree(pytree, /):
return jax.flatten_util.ravel_pytree(pytree)


def tree_leaves(pytree, /):
return jax.tree.leaves(pytree)


def tree_structure(pytree, /):
return jax.tree.structure(pytree)


def partial_pytree(func, /):
return jax.tree.Partial(func)
28 changes: 0 additions & 28 deletions matfree/backend/tree_util.py

This file was deleted.

8 changes: 4 additions & 4 deletions matfree/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
[matfree.funm][matfree.funm].
"""

from matfree.backend import containers, control_flow, func, linalg, np, tree_util
from matfree.backend import containers, control_flow, func, linalg, np, tree
from matfree.backend.typing import Array, Callable, Union


Expand Down Expand Up @@ -304,7 +304,7 @@ def adjoint_step(xi_and_lambda, inputs):
)

# Compute the gradients
grad_matvec = tree_util.tree_map(lambda s: np.sum(s, axis=0), grad_summands)
grad_matvec = tree.tree_map(lambda s: np.sum(s, axis=0), grad_summands)
grad_initvec = ((lambda_1.T @ xs[0]) * xs[0] - lambda_1) / initvec_norm

# Return values
Expand Down Expand Up @@ -477,7 +477,7 @@ def lower(m):
lambda_k = dr + Q @ eta
Lambda = np.zeros_like(Q)
Gamma = np.zeros_like(dQ.T @ Q)
dp = tree_util.tree_map(np.zeros_like, params)
dp = tree.tree_map(np.zeros_like, params)

# Prepare more auxiliary matrices
Pi_xi = dQ.T + linalg.outer(eta, r)
Expand Down Expand Up @@ -562,7 +562,7 @@ def _hessenberg_adjoint_step(
# Transposed matvec and parameter-gradient in a single matvec
_, vjp = func.vjp(lambda u, v: matvec(u, *v), q, params)
vecmat_lambda, dp_increment = vjp(lambda_k)
dp = tree_util.tree_map(lambda g, h: g + h, dp, dp_increment)
dp = tree.tree_map(lambda g, h: g + h, dp, dp_increment)

# Solve for (Gamma + Gamma.T) e_K
tmp = lower_mask * (Pi_gamma - vecmat_lambda @ Q)
Expand Down
10 changes: 5 additions & 5 deletions matfree/funm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

"""

from matfree.backend import containers, control_flow, func, linalg, np, tree_util
from matfree.backend import containers, control_flow, func, linalg, np, tree
from matfree.backend.typing import Array, Callable


Expand Down Expand Up @@ -212,14 +212,14 @@ def integrand_funm_sym(dense_funm, tridiag_sym, /):
"""

def quadform(matvec, v0, *parameters):
v0_flat, v_unflatten = tree_util.ravel_pytree(v0)
v0_flat, v_unflatten = tree.ravel_pytree(v0)
length = linalg.vector_norm(v0_flat)
v0_flat /= length

def matvec_flat(v_flat, *p):
v = v_unflatten(v_flat)
Av = matvec(v, *p)
flat, unflatten = tree_util.ravel_pytree(Av)
flat, unflatten = tree.ravel_pytree(Av)
return flat

_, dense, *_ = tridiag_sym(matvec_flat, v0_flat, *parameters)
Expand Down Expand Up @@ -260,14 +260,14 @@ def integrand_funm_product(dense_funm, algorithm, /):
"""

def quadform(matvec, v0, *parameters):
v0_flat, v_unflatten = tree_util.ravel_pytree(v0)
v0_flat, v_unflatten = tree.ravel_pytree(v0)
length = linalg.vector_norm(v0_flat)
v0_flat /= length

def matvec_flat(v_flat, *p):
v = v_unflatten(v_flat)
Av = matvec(v, *p)
flat, _unflatten = tree_util.ravel_pytree(Av)
flat, _unflatten = tree.ravel_pytree(Av)
return flat

# Decompose into orthogonal-bidiag-orthogonal
Expand Down
24 changes: 12 additions & 12 deletions matfree/stochtrace.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Stochastic estimation of traces, diagonals, and more."""

from matfree.backend import func, linalg, np, prng, tree_util
from matfree.backend import func, linalg, np, prng, tree
from matfree.backend.typing import Callable


Expand Down Expand Up @@ -30,7 +30,7 @@ def estimator(integrand: Callable, /, sampler: Callable) -> Callable:
def estimate(matvecs, key, *parameters):
samples = sampler(key)
Qs = func.vmap(lambda vec: integrand(matvecs, vec, *parameters))(samples)
return tree_util.tree_map(lambda s: np.mean(s, axis=0), Qs)
return tree.tree_map(lambda s: np.mean(s, axis=0), Qs)

return estimate

Expand All @@ -49,8 +49,8 @@ def integrand_diagonal():

def integrand(matvec, v, *parameters):
Qv = matvec(v, *parameters)
v_flat, unflatten = tree_util.ravel_pytree(v)
Qv_flat, _unflatten = tree_util.ravel_pytree(Qv)
v_flat, unflatten = tree.ravel_pytree(v)
Qv_flat, _unflatten = tree.ravel_pytree(Qv)
return unflatten(v_flat * Qv_flat)

return integrand
Expand All @@ -61,8 +61,8 @@ def integrand_trace():

def integrand(matvec, v, *parameters):
Qv = matvec(v, *parameters)
v_flat, unflatten = tree_util.ravel_pytree(v)
Qv_flat, _unflatten = tree_util.ravel_pytree(Qv)
v_flat, unflatten = tree.ravel_pytree(v)
Qv_flat, _unflatten = tree.ravel_pytree(Qv)
return linalg.inner(v_flat, Qv_flat)

return integrand
Expand All @@ -73,8 +73,8 @@ def integrand_trace_and_diagonal():

def integrand(matvec, v, *parameters):
Qv = matvec(v, *parameters)
v_flat, unflatten = tree_util.ravel_pytree(v)
Qv_flat, _unflatten = tree_util.ravel_pytree(Qv)
v_flat, unflatten = tree.ravel_pytree(v)
Qv_flat, _unflatten = tree.ravel_pytree(Qv)
trace_form = linalg.inner(v_flat, Qv_flat)
diagonal_form = unflatten(v_flat * Qv_flat)
return {"trace": trace_form, "diagonal": diagonal_form}
Expand All @@ -87,7 +87,7 @@ def integrand_frobeniusnorm_squared():

def integrand(matvec, vec, *parameters):
x = matvec(vec, *parameters)
v_flat, unflatten = tree_util.ravel_pytree(x)
v_flat, unflatten = tree.ravel_pytree(x)
return linalg.inner(v_flat, v_flat)

return integrand
Expand Down Expand Up @@ -119,10 +119,10 @@ def integrand_wrap_moments(integrand, /, moments):

def integrand_wrapped(vec, *parameters):
Qs = integrand(vec, *parameters)
return tree_util.tree_map(moment_fun, Qs)
return tree.tree_map(moment_fun, Qs)

def moment_fun(x, /):
return tree_util.tree_map(lambda m: x**m, moments)
return tree.tree_map(lambda m: x**m, moments)

return integrand_wrapped

Expand All @@ -138,7 +138,7 @@ def sampler_rademacher(*args_like, num):


def _sampler_from_jax_random(sampler, *args_like, num):
x_flat, unflatten = tree_util.ravel_pytree(*args_like)
x_flat, unflatten = tree.ravel_pytree(*args_like)

def sample(key):
samples = sampler(key, shape=(num, *x_flat.shape), dtype=x_flat.dtype)
Expand Down
6 changes: 3 additions & 3 deletions matfree/test_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Test utilities."""

from matfree.backend import linalg, np, prng, tree_util
from matfree.backend import linalg, np, prng, tree


def symmetric_matrix_from_eigenvalues(eigvals, /):
Expand Down Expand Up @@ -45,9 +45,9 @@ def to_dense_tridiag_sym(d, e, /):
return diag + offdiag1 + offdiag2


def tree_random_like(key, tree, *, generate_func=prng.normal):
def tree_random_like(key, pytree, *, generate_func=prng.normal):
"""Fill a tree with random values."""
flat, unflatten = tree_util.ravel_pytree(tree)
flat, unflatten = tree.ravel_pytree(pytree)
flat_like = generate_func(key, shape=flat.shape, dtype=flat.dtype)
return unflatten(flat_like)

Expand Down
4 changes: 2 additions & 2 deletions tests/test_decomp/test_consistency.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Ensure that bidiag, tridiag, etc. have consistent signatures."""

from matfree import decomp
from matfree.backend import np, prng, testing, tree_util
from matfree.backend import np, prng, testing, tree


def case_method_bidiag():
Expand Down Expand Up @@ -38,7 +38,7 @@ def test_output_shape_as_expected(nrows, num_matvecs, method):
Us, B, res, ln = algorithm(lambda v: A @ v, v0)

# Normalise the Us to always have a list
Us = tree_util.tree_leaves(Us)
Us = tree.tree_leaves(Us)

for U in Us:
assert np.shape(U) == (nrows, num_matvecs)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_decomp/test_tridiag_sym_adjoint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test the adjoint of tri-diagonalisation."""

from matfree import decomp, test_util
from matfree.backend import func, linalg, np, prng, testing, tree_util
from matfree.backend import func, linalg, np, prng, testing, tree


@testing.parametrize("reortho", ["full", "none"])
Expand All @@ -19,14 +19,14 @@ def matvec(s, p):
vector = prng.normal(prng.prng_key(1), shape=(n,))

# Flatten the inputs
flat, unflatten = tree_util.ravel_pytree((vector, params))
flat, unflatten = tree.ravel_pytree((vector, params))

# Construct a vector-to-vector decomposition function
def decompose(f, *, custom_vjp):
kwargs = {"reortho": reortho, "custom_vjp": custom_vjp, "materialize": False}
algorithm = decomp.tridiag_sym(krylov_num_matvecs, **kwargs)
output = algorithm(matvec, *unflatten(f))
return tree_util.ravel_pytree(output)[0]
return tree.ravel_pytree(output)[0]

# Construct the two implementations
reference = func.jit(func.partial(decompose, custom_vjp=False))
Expand Down
6 changes: 3 additions & 3 deletions tests/test_stochtrace/test_diagonal.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test the diagonal estimation."""

from matfree import stochtrace
from matfree.backend import func, linalg, np, prng, tree_util
from matfree.backend import func, linalg, np, prng, tree


def test_diagonal():
Expand All @@ -20,7 +20,7 @@ def fun(x):
_, jvp = func.linearize(fun, args_like)
J = func.jacfwd(fun)(args_like)["params"]

expected = tree_util.tree_map(linalg.diagonal, J)
expected = tree.tree_map(linalg.diagonal, J)

# Estimate the matrix function
problem = stochtrace.integrand_diagonal()
Expand All @@ -31,4 +31,4 @@ def fun(x):
def compare(a, b):
return np.allclose(a, b, rtol=1e-2)

assert tree_util.tree_all(tree_util.tree_map(compare, received, expected))
assert tree.tree_all(tree.tree_map(compare, received, expected))
4 changes: 2 additions & 2 deletions tests/test_stochtrace/test_frobeniusnorm_squared.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test the estimation of squared Frobenius-norms."""

from matfree import stochtrace
from matfree.backend import func, linalg, np, prng, tree_util
from matfree.backend import func, linalg, np, prng, tree


def test_frobeniusnorm_squared():
Expand All @@ -18,7 +18,7 @@ def fun(x):
x0 = prng.uniform(key, shape=(4,)) # random lin. point
args_like = {"params": x0}
_, jvp = func.linearize(fun, args_like)
[J] = tree_util.tree_leaves(func.jacfwd(fun)(args_like))
[J] = tree.tree_leaves(func.jacfwd(fun)(args_like))
expected = linalg.trace(J.T @ J)

# Estimate the matrix function
Expand Down
4 changes: 2 additions & 2 deletions tests/test_stochtrace/test_trace.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test trace estimation."""

from matfree import stochtrace
from matfree.backend import func, linalg, np, prng, tree_util
from matfree.backend import func, linalg, np, prng, tree


def test_trace():
Expand Down Expand Up @@ -30,4 +30,4 @@ def fun(x):
def compare(a, b):
return np.allclose(a, b, rtol=1e-2)

assert tree_util.tree_all(tree_util.tree_map(compare, received, expected))
assert tree.tree_all(tree.tree_map(compare, received, expected))
4 changes: 2 additions & 2 deletions tests/test_stochtrace/test_trace_and_diagonal.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test joint trace and diagonal estimation."""

from matfree import stochtrace
from matfree.backend import func, np, prng, tree_util
from matfree.backend import func, np, prng, tree


def test_trace_and_diagonal():
Expand Down Expand Up @@ -39,4 +39,4 @@ def fun(x):

expected = {"trace": expected_trace, "diagonal": expected_diagonal}

assert tree_util.tree_all(tree_util.tree_map(np.allclose, received, expected))
assert tree.tree_all(tree.tree_map(np.allclose, received, expected))
4 changes: 2 additions & 2 deletions tests/test_stochtrace/test_wrap_moments.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test the estimation with multiple statistics."""

from matfree import stochtrace
from matfree.backend import func, linalg, np, prng, testing, tree_util
from matfree.backend import func, linalg, np, prng, testing, tree


def test_yields_correct_tree_structure():
Expand Down Expand Up @@ -32,7 +32,7 @@ def fun(x):
expected = {
"params": {"moment_1st": irrelevant_value, "moment_2nd": irrelevant_value}
}
assert tree_util.tree_structure(received) == tree_util.tree_structure(expected)
assert tree.tree_structure(received) == tree.tree_structure(expected)


@testing.fixture(name="key")
Expand Down
8 changes: 4 additions & 4 deletions tutorials/2_pytree_logdeterminants.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def testfunc(x):

f0, jvp = jax.linearize(testfunc, x0)
_f0, vjp = jax.vjp(testfunc, x0)
print(jax.tree_util.tree_map(jnp.shape, f0))
print(jax.tree_util.tree_map(jnp.shape, jvp(x0)))
print(jax.tree_util.tree_map(jnp.shape, vjp(f0)))
print(jax.tree.map(jnp.shape, f0))
print(jax.tree.map(jnp.shape, jvp(x0)))
print(jax.tree.map(jnp.shape, vjp(f0)))


# Use the same API as if the matrix-vector product were array-valued.
Expand All @@ -46,7 +46,7 @@ def fun(fx, /):
r"""Matrix-vector product with $J J^\top + \alpha I$."""
vjp_eval = vjp(fx)
matvec_eval = jvp(*vjp_eval)
return jax.tree_util.tree_map(lambda x, y: x + alpha * y, matvec_eval, fx)
return jax.tree.map(lambda x, y: x + alpha * y, matvec_eval, fx)

return fun

Expand Down
Loading