diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py
index d8de214f5..8ab4adbb9 100644
--- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py
+++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py
@@ -11,6 +11,7 @@
 from flax import jax_utils
 from flax import linen as nn
+from flax.core import pop
 import jax
 from jax import lax
 import jax.numpy as jnp
@@ -79,8 +80,8 @@ def sync_batch_stats(
     # In this case each device has its own version of the batch statistics and
     # we average them.
     avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
-    new_model_state = model_state.copy(
-        {'batch_stats': avg_fn(model_state['batch_stats'])})
+    new_model_state = model_state.copy()  # Create a shallow copy
+    new_model_state['batch_stats'] = avg_fn(model_state['batch_stats'])
     return new_model_state

   def init_model_fn(
@@ -111,7 +112,7 @@
     input_shape = (1, 224, 224, 3)
     variables = jax.jit(model.init)({'params': rng},
                                     jnp.ones(input_shape, model.dtype))
-    model_state, params = variables.pop('params')
+    model_state, params = pop(variables, "params")
     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
     model_state = jax_utils.replicate(model_state)
diff --git a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
index 2ad71ffd0..5f826d035 100644
--- a/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
+++ b/algorithmic_efficiency/workloads/imagenet_vit/imagenet_jax/workload.py
@@ -4,6 +4,7 @@
 from flax import jax_utils
 from flax import linen as nn
+from flax.core import pop
 import jax
 import jax.numpy as jnp
@@ -28,7 +29,7 @@ def initialized(self, key: spec.RandomState,
     variables = jax.jit(
         model.init)({'params': params_rng, 'dropout': dropout_rng},
                     jnp.ones(input_shape))
-    model_state, params = variables.pop('params')
+    model_state, params = pop(variables, "params")
     return params, model_state

   def init_model_fn(
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py
index f4d1ab0f3..d805e8b17 100644
--- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py
+++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py
@@ -3,6 +3,7 @@
 from typing import Dict, Iterator, Optional, Tuple

 from flax import jax_utils
+from flax.core import pop
 import flax.linen as nn
 import jax
 from jax import lax
@@ -89,7 +90,7 @@ def init_model_fn(
     variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng},
                               *fake_input_batch)

-    model_state, params = variables.pop('params')
+    model_state, params = pop(variables, "params")

     self._param_shapes = param_utils.jax_param_shapes(params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
@@ -374,8 +375,8 @@ def sync_batch_stats(
     # In this case each device has its own version of the batch statistics and
     # we average them.
     avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
-    new_model_state = model_state.copy(
-        {'batch_stats': avg_fn(model_state['batch_stats'])})
+    new_model_state = model_state.copy()
+    new_model_state['batch_stats'] = avg_fn(model_state['batch_stats'])
     return new_model_state
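For context, the sketch below is a minimal standalone illustration of the pattern these hunks adopt; it is not part of the patch, and TinyModel and the shapes are hypothetical. flax.core.pop splits the variable collections returned by model.init into (model_state, params) whether Flax returns a plain dict or a FrozenDict, and batch statistics are then updated by taking a mutable shallow copy of the state and reassigning the 'batch_stats' entry instead of calling FrozenDict.copy with a replacement mapping.

# Standalone sketch of the pattern adopted above; TinyModel is a
# hypothetical stand-in for the workload models.
from flax import linen as nn
from flax.core import pop
import jax
import jax.numpy as jnp


class TinyModel(nn.Module):

  @nn.compact
  def __call__(self, x, train=True):
    x = nn.Dense(4)(x)
    # BatchNorm creates a 'batch_stats' collection alongside 'params'.
    x = nn.BatchNorm(use_running_average=not train)(x)
    return x


rng = jax.random.PRNGKey(0)
variables = TinyModel().init({'params': rng}, jnp.ones((1, 8)))

# Same split as the patched init_model_fn: trainable parameters go to
# `params`, everything else (here only 'batch_stats') stays in model_state.
# flax.core.pop works for both plain dicts and FrozenDicts.
model_state, params = pop(variables, 'params')
print(list(model_state.keys()))  # ['batch_stats']

# The sync_batch_stats change follows the same idea: take a mutable shallow
# copy and reassign the entry rather than FrozenDict.copy({'batch_stats': ...}).
# The unchanged stats stand in for the pmean-averaged values used in the patch.
new_model_state = dict(model_state)
new_model_state['batch_stats'] = model_state['batch_stats']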