Skip to content

Commit

Permalink
fix: using flax.core.pop instead of variables.pop, better way to update batch_stats
Browse files Browse the repository at this point in the history
  • Loading branch information
init-22 committed Dec 2, 2024
1 parent 86029a7 commit aca45a2
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from flax import jax_utils
from flax import linen as nn
from flax.core import pop
import jax
from jax import lax
import jax.numpy as jnp
Expand Down Expand Up @@ -79,8 +80,8 @@ def sync_batch_stats(
# In this case each device has its own version of the batch statistics and
# we average them.
avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
new_model_state = model_state.copy(
{'batch_stats': avg_fn(model_state['batch_stats'])})
new_model_state = model_state.copy() # Create a shallow copy
new_model_state['batch_stats'] = avg_fn(model_state['batch_stats'])
return new_model_state

def init_model_fn(
Expand Down Expand Up @@ -111,7 +112,7 @@ def init_model_fn(
input_shape = (1, 224, 224, 3)
variables = jax.jit(model.init)({'params': rng},
jnp.ones(input_shape, model.dtype))
model_state, params = variables.pop('params')
model_state, params = pop(variables, "params")
self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
model_state = jax_utils.replicate(model_state)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from flax import jax_utils
from flax import linen as nn
from flax.core import pop
import jax
import jax.numpy as jnp

Expand All @@ -28,7 +29,7 @@ def initialized(self, key: spec.RandomState,
variables = jax.jit(
model.init)({'params': params_rng, 'dropout': dropout_rng},
jnp.ones(input_shape))
model_state, params = variables.pop('params')
model_state, params = pop(variables, "params")
return params, model_state

def init_model_fn(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Dict, Iterator, Optional, Tuple

from flax import jax_utils
from flax.core import pop
import flax.linen as nn
import jax
from jax import lax
Expand Down Expand Up @@ -89,7 +90,7 @@ def init_model_fn(
variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng},
*fake_input_batch)

model_state, params = variables.pop('params')
model_state, params = pop(variables, "params")

self._param_shapes = param_utils.jax_param_shapes(params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
Expand Down Expand Up @@ -374,8 +375,8 @@ def sync_batch_stats(
# In this case each device has its own version of the batch statistics and
# we average them.
avg_fn = jax.pmap(lambda x: lax.pmean(x, 'x'), 'x')
new_model_state = model_state.copy(
{'batch_stats': avg_fn(model_state['batch_stats'])})
new_model_state = model_state.copy()
new_model_state['batch_stats'] = avg_fn(model_state['batch_stats'])
return new_model_state


Expand Down

0 comments on commit aca45a2

Please sign in to comment.