fix: unfreeze() in test_param_shapes expects a FrozenDict; also use flax.core.pop instead of variables.pop
init-22 committed Dec 2, 2024
1 parent 2618c5e commit 8c90625
Showing 2 changed files with 9 additions and 4 deletions.
7 changes: 4 additions & 3 deletions algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py
@@ -5,6 +5,7 @@

 from flax import jax_utils
 from flax import linen as nn
+from flax.core import pop
 import jax
 from jax import lax
 import jax.numpy as jnp
@@ -75,8 +76,8 @@ def sync_batch_stats(
     # In this case each device has its own version of the batch statistics
     # and we average them.
     avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
-    new_model_state = model_state.copy(
-        {'batch_stats': avg_fn(model_state['batch_stats'])})
+    new_model_state = model_state.copy()
+    new_model_state['batch_stats'] = avg_fn(model_state['batch_stats'])
     return new_model_state

   def init_model_fn(
@@ -93,7 +94,7 @@ def init_model_fn(
     input_shape = (1, 32, 32, 3)
     variables = jax.jit(model.init)({'params': rng},
                                     jnp.ones(input_shape, model.dtype))
-    model_state, params = variables.pop('params')
+    model_state, params = pop(variables, 'params')
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
     model_state = jax_utils.replicate(model_state)
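The copy() change above matters because recent Flax versions return plain Python dicts from Module.init rather than FrozenDicts. A minimal sketch of the difference, with made-up batch statistics (only flax itself is assumed):

from flax.core import FrozenDict

# FrozenDict.copy accepts an add_or_replace mapping, so the old
# one-argument call worked while model_state was a FrozenDict:
frozen_state = FrozenDict({'batch_stats': {'mean': 0.0}})
updated = frozen_state.copy({'batch_stats': {'mean': 1.0}})

# dict.copy() takes no arguments, so the same call on a plain dict raises
# TypeError; hence the copy-then-assign form in the commit:
plain_state = {'batch_stats': {'mean': 0.0}}
new_state = plain_state.copy()
new_state['batch_stats'] = {'mean': 1.0}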
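Similarly, variables.pop('params') relied on FrozenDict.pop, which returns a (rest, value) pair, whereas dict.pop returns only the value and mutates in place. flax.core.pop yields the pair for either container type. A minimal sketch with made-up variables:

from flax.core import FrozenDict, pop

variables = {'params': {'w': 1.0}, 'batch_stats': {'mean': 0.0}}

# flax.core.pop returns (everything but the key, the popped value) for
# plain dicts and FrozenDicts alike:
model_state, params = pop(variables, 'params')
assert 'params' not in model_state

# FrozenDict.pop follows the same (rest, value) convention:
frozen = FrozenDict({'params': {'w': 1.0}, 'batch_stats': {'mean': 0.0}})
frozen_state, frozen_params = frozen.pop('params')

# dict.pop, by contrast, returns only the value, so the old unpacking
# `model_state, params = variables.pop('params')` fails on a plain dict.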
6 changes: 5 additions & 1 deletion tests/test_param_shapes.py
@@ -3,6 +3,7 @@
 import jax
 import numpy as np
 import pytest
+from flax.core import FrozenDict

 # isort: skip_file
 # pylint:disable=line-too-long
@@ -51,8 +52,11 @@
 def test_param_shapes(workload):
   jax_workload, pytorch_workload = get_workload(workload)
   # Compare number of parameter tensors of both models.
+  jax_workload_param_shapes = jax_workload.param_shapes
+  if isinstance(jax_workload_param_shapes, dict):
+    jax_workload_param_shapes = FrozenDict(jax_workload_param_shapes)
   jax_param_shapes = jax.tree_util.tree_leaves(
-      jax_workload.param_shapes.unfreeze())
+      jax_workload_param_shapes.unfreeze())
   pytorch_param_shapes = jax.tree_util.tree_leaves(
       pytorch_workload.param_shapes)
   if workload == 'wmt':
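The test fix works because unfreeze() exists only on FrozenDict, so a workload that returns a plain dict is first wrapped to make the call valid either way. A minimal sketch, with placeholder strings standing in for the workload's shape objects (which jax.tree_util treats as leaves):

import jax
from flax.core import FrozenDict

# Placeholder shapes; the real workloads store shape objects here.
param_shapes = {'dense': {'kernel': 'shape(32, 10)', 'bias': 'shape(10,)'}}

if isinstance(param_shapes, dict):
  param_shapes = FrozenDict(param_shapes)

# unfreeze() converts back to a plain nested dict, which tree_leaves flattens.
leaves = jax.tree_util.tree_leaves(param_shapes.unfreeze())
print(len(leaves))  # 2 parameter tensors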
