Skip to content

Commit

Permalink
Add edit_tree
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Sep 10, 2024
1 parent 1efee84 commit 8a4f8d6
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ._src.shims import custom_jvp, custom_jvp_method, custom_vjp, custom_vjp_method, hessian, jit
from ._src.testing import (assert_tree_allclose, get_relative_test_string, get_test_string,
tree_allclose)
from ._src.tree_tools import dynamic_tree_all
from ._src.tree_tools import dynamic_tree_all, NotEditableError, edit_tree

__all__ = [
'Array',
Expand All @@ -51,6 +51,7 @@
'JaxIntegralArray',
'JaxRealArray',
'KeyArray',
'NotEditableError',
'NumpyArray',
'NumpyBooleanArray',
'NumpyBooleanNumeric',
Expand Down Expand Up @@ -93,6 +94,7 @@
'divide_nonnegative',
'divide_where',
'dynamic_tree_all',
'edit_tree',
'fork_streams',
'get_relative_test_string',
'get_test_string',
Expand Down
22 changes: 22 additions & 0 deletions tjax/_src/tree_tools.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,35 @@
from __future__ import annotations

from collections.abc import Generator
from contextlib import contextmanager
from typing import Any

import jax
import jax.numpy as jnp
from jax import tree

from .annotations import JaxBooleanArray


def dynamic_tree_all(tree: Any) -> JaxBooleanArray:
"""Like `jax.tree.all`, but can be used in dynamic code like jitted functions and loops."""
return jax.tree.reduce(jnp.logical_and, tree, jnp.asarray(True))


class NotEditableError(RuntimeError):
pass


@contextmanager
def edit_tree[T](model: T, /, editable_types: tuple[type[Any], ...]) -> Generator[T, None, None]:
flattened, pytree = tree.flatten(model)
model_copy = tree.unflatten(pytree, flattened)
yield model_copy

def verify(x: Any, y: Any) -> None:
if isinstance(x, editable_types):
return
if id(x) != id(y):
raise NotEditableError("Non-editable value changed.") # noqa: TRY003

tree.map(verify, model, model_copy, is_leaf=lambda x: isinstance(x, editable_types))

0 comments on commit 8a4f8d6

Please sign in to comment.