Skip to content

Commit

Permalink
Update for Flax 0.9
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Sep 9, 2024
1 parent b4f5f75 commit 7230a99
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 168 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ repository = 'https://github.com/NeilGirdhar/tjax'
[tool.uv]
dev-dependencies = [
'chex >= 0.1.3',
'flax >= 0.8.4',
'flax >= 0.9.0',
'isort >= 5.11',
'jupyter >= 1',
'mypy >= 1.7',
Expand Down
6 changes: 3 additions & 3 deletions tests/test_flax_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def graph() -> nx.DiGraph[Any]:


def test_rebuild(graph: nx.DiGraph[Any]) -> None:
graph_def, state, _ = nnx.graph.flatten(graph)
rebuilt_graph, _ = nnx.graph.unflatten(graph_def, state)
graph_def, state = nnx.graph.flatten(graph)
rebuilt_graph = nnx.graph.unflatten(graph_def, state)
assert nx.utils.graphs_equal(graph, rebuilt_graph)


Expand All @@ -46,7 +46,7 @@ def __init__(self) -> None:


def test_flatten(graph: nx.DiGraph[Any]) -> None:
_, state, _ = nnx.graph.flatten(graph)
_, state = nnx.graph.flatten(graph)
substate = state[GraphEdgeKey('a', 'b')]
assert isinstance(substate, nnx.State)
variable = substate['x']
Expand Down
Loading

0 comments on commit 7230a99

Please sign in to comment.