diff --git a/brax/envs/fast.py b/brax/envs/fast.py index b7728524..7541a598 100644 --- a/brax/envs/fast.py +++ b/brax/envs/fast.py @@ -74,9 +74,7 @@ def reset(self, rng: jax.Array) -> State: 'pixels/view_1': jp.zeros((4, 4, 3)), } - if self._obs_mode == ObservationMode.DICT_STATE: - obs = obs - elif self._obs_mode == ObservationMode.DICT_PIXELS: + if self._obs_mode == ObservationMode.DICT_PIXELS: obs = pixels elif self._obs_mode == ObservationMode.DICT_PIXELS_STATE: obs = {**obs, **pixels} @@ -104,9 +102,7 @@ def step(self, state: State, action: jax.Array) -> State: 'pixels/view_1': jp.zeros((4, 4, 3)), } - if self._obs_mode == ObservationMode.DICT_STATE: - obs = obs - elif self._obs_mode == ObservationMode.DICT_PIXELS: + if self._obs_mode == ObservationMode.DICT_PIXELS: obs = pixels elif self._obs_mode == ObservationMode.DICT_PIXELS_STATE: obs = {**obs, **pixels} diff --git a/brax/training/agents/ppo/train.py b/brax/training/agents/ppo/train.py index 4d712a47..af8c3bc8 100644 --- a/brax/training/agents/ppo/train.py +++ b/brax/training/agents/ppo/train.py @@ -89,6 +89,7 @@ def _random_translate_pixels( Returns: A dictionary of observations with translated pixels """ + obs = core.FrozenDict(obs) @jax.vmap def rt_all_views( diff --git a/brax/training/networks.py b/brax/training/networks.py index 448079d6..ad5e392a 100644 --- a/brax/training/networks.py +++ b/brax/training/networks.py @@ -370,6 +370,7 @@ def make_policy_network_vision( ) def apply(processor_params, policy_params, obs): + obs = core.FrozenDict(obs) if state_obs_key: state_obs = preprocess_observations_fn( obs[state_obs_key], normalizer_select(processor_params, state_obs_key)