diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index f70a81be2..4acf84708 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -17,6 +17,7 @@ New Features: ^^^^^^^^^^^^^ - Added ``repeat_action_probability`` argument in ``AtariWrapper``. - Only use ``NoopResetEnv`` and ``MaxAndSkipEnv`` when needed in ``AtariWrapper`` +- Added ``next_observations`` and ``has_next_observation`` fields to ``RolloutBufferSamples`` (@eohomegrownapps) `SB3-Contrib`_ ^^^^^^^^^^^^^^ @@ -1227,4 +1228,4 @@ And all the contributors: @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @yuanmingqi @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong -@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError +@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @eohomegrownapps diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 273dba9e0..a2d64d2ab 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -459,6 +459,11 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSample for tensor in _tensor_names: self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + + is_terminal = np.roll(self.episode_starts, -1, axis=0) + is_terminal[-1] = np.ones_like(is_terminal[-1]) + self.has_next_observation = 1.0 - self.swap_and_flatten(is_terminal) + self.generator_ready = True # Return everything, don't create minibatches @@ -475,8 +480,11 @@ def _get_samples( batch_inds: np.ndarray, env: Optional[VecNormalize] = None, ) -> RolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME + n = self.observations.shape[0] data = ( self.observations[batch_inds], + self.observations[(batch_inds + 1) % n], + self.has_next_observation[batch_inds].flatten(), 
self.actions[batch_inds], self.values[batch_inds].flatten(), self.log_probs[batch_inds].flatten(), @@ -765,6 +773,11 @@ def get( for tensor in _tensor_names: self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + + is_terminal = np.roll(self.episode_starts, -1, axis=0) + is_terminal[-1] = np.ones_like(is_terminal[-1]) + self.has_next_observation = 1.0 - self.swap_and_flatten(is_terminal) + self.generator_ready = True # Return everything, don't create minibatches @@ -781,8 +794,13 @@ def _get_samples( batch_inds: np.ndarray, env: Optional[VecNormalize] = None, ) -> DictRolloutBufferSamples: # type: ignore[signature-mismatch] #FIXME + n = self.actions.shape[0] return DictRolloutBufferSamples( observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()}, + next_observations={ + key: self.to_torch(obs[(batch_inds + 1) % n]) for (key, obs) in self.observations.items() + }, + has_next_observation=self.to_torch(self.has_next_observation[batch_inds].flatten()), actions=self.to_torch(self.actions[batch_inds]), old_values=self.to_torch(self.values[batch_inds].flatten()), old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 7227667a1..a7e8281f7 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -29,6 +29,8 @@ class RolloutBufferSamples(NamedTuple): observations: th.Tensor + next_observations: th.Tensor + has_next_observation: th.Tensor actions: th.Tensor old_values: th.Tensor old_log_prob: th.Tensor @@ -38,6 +40,8 @@ class DictRolloutBufferSamples(NamedTuple): observations: TensorDict + next_observations: TensorDict + has_next_observation: th.Tensor actions: th.Tensor old_values: th.Tensor old_log_prob: th.Tensor diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 0e028e670..8d9ac2b9f 100644 --- 
a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -6,7 +6,7 @@ from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer from stable_baselines3.common.env_util import make_vec_env -from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples +from stable_baselines3.common.type_aliases import DictReplayBufferSamples, DictRolloutBufferSamples, ReplayBufferSamples from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import VecNormalize @@ -139,3 +139,32 @@ def test_device_buffer(replay_buffer_cls, device): assert value[key].device.type == desired_device elif isinstance(value, th.Tensor): assert value.device.type == desired_device + + +@pytest.mark.parametrize("rollout_buffer_cls", [RolloutBuffer, DictRolloutBuffer]) +def test_next_observations(rollout_buffer_cls): + env = {RolloutBuffer: DummyEnv, DictRolloutBuffer: DummyDictEnv}[rollout_buffer_cls] + env = make_vec_env(env) + + buffer = rollout_buffer_cls(100, env.observation_space, env.action_space, device="cpu") + + obs = env.reset() + for _ in range(100): + action = env.action_space.sample() + next_obs, reward, done, info = env.step(action) + values, log_prob = th.zeros(1), th.ones(1) + if isinstance(obs, dict): + buffer.add(obs, action, reward, (obs["observation"] == 1.0), values, log_prob) + else: + buffer.add(obs, action, reward, (obs == 1.0), values, log_prob) + obs = next_obs + + data = buffer.get(50) + for dp in data: + if isinstance(dp, DictRolloutBufferSamples): + for k in dp.observations.keys(): + assert th.equal((dp.observations[k] % 5) + 1, dp.next_observations[k]) + assert th.equal(th.flatten(dp.observations[k] != 5), dp.has_next_observation) + else: + assert th.equal((dp.observations % 5) + 1.0, dp.next_observations) + assert th.equal(th.flatten(dp.observations != 5), dp.has_next_observation)