Prioritized replay buffer #175
@@ -1,3 +1,3 @@
-from .replay_buffer import ReplayBuffer
+from .replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
 from .trajectories import Trajectories
 from .transitions import Transitions
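The new export is the PR's main addition. As a rough illustration of the idea behind a prioritized buffer (a toy sketch only, not this PR's implementation; the class and method names here are hypothetical): instead of evicting the oldest entries, keep the entries with the highest log-reward.

class TinyPrioritizedBuffer:
    """Toy illustration: retain the highest-log-reward items up to capacity."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.items: list[tuple[float, object]] = []  # (log_reward, payload)

    def add(self, log_reward: float, payload: object) -> None:
        self.items.append((log_reward, payload))
        # Keep only the `capacity` best items, highest log_reward first.
        self.items.sort(key=lambda pair: pair[0], reverse=True)
        del self.items[self.capacity:]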
@@ -232,22 +232,34 @@ def extend(self, other: Trajectories) -> None:
         self.states.extend(other.states)
         self.when_is_done = torch.cat((self.when_is_done, other.when_is_done), dim=0)

-        # For log_probs, we first need to make the first dimensions of self.log_probs and other.log_probs equal
-        # (i.e. the number of steps in the trajectories), and then concatenate them
+        # For log_probs, we first need to make the first dimensions of self.log_probs
+        # and other.log_probs equal (i.e. the number of steps in the trajectories), and
+        # then concatenate them.
         new_max_length = max(self.log_probs.shape[0], other.log_probs.shape[0])
         self.log_probs = self.extend_log_probs(self.log_probs, new_max_length)
         other.log_probs = self.extend_log_probs(other.log_probs, new_max_length)

         self.log_probs = torch.cat((self.log_probs, other.log_probs), dim=1)

+        # Concatenate log_rewards of the trajectories.
         if self._log_rewards is not None and other._log_rewards is not None:
             self._log_rewards = torch.cat(
                 (self._log_rewards, other._log_rewards),
                 dim=0,
             )
+        # If the trajectories object does not yet have `log_rewards` assigned but the
+        # external trajectory has log_rewards, simply assign them over.
+        elif self._log_rewards is None and other._log_rewards is not None:
+            self._log_rewards = other._log_rewards
         else:
             self._log_rewards = None

+        # Ensure log_probs/rewards are the correct dimensions. TODO: Remove?
+        if self.log_probs.numel() > 0:
+            assert self.log_probs.shape == self.actions.batch_shape
+
+        if self.log_rewards is not None:
+            assert len(self.log_rewards) == self.actions.batch_shape[-1]
+

Review comment on the new `elif` branch: "why is this needed?"

Reply: "We had a situation in the buffer where the empty initialized trajectory had `log_rewards` set to `None`."
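For context, the padding step near the top of the hunk makes both log_probs tensors the same length along the step dimension before the dim=1 concatenation. A minimal sketch of what extend_log_probs plausibly does, assuming a [max_length, n_trajectories] layout and a zero fill value (the repo's actual helper may pad differently):

import torch

def extend_log_probs(log_probs: torch.Tensor, new_max_length: int) -> torch.Tensor:
    # Pad dim 0 (trajectory steps) up to new_max_length so that two tensors of
    # shape [max_length, n_trajectories] can later be concatenated along dim 1.
    n_pad = new_max_length - log_probs.shape[0]
    if n_pad <= 0:
        return log_probs
    pad = torch.zeros(
        n_pad, log_probs.shape[1], dtype=log_probs.dtype, device=log_probs.device
    )
    return torch.cat((log_probs, pad), dim=0)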
Comment on lines +257 to +263 (the new assertion block):

"fair"

"yes good to check -- but maybe ideal that we don't have to check -- this adds overhead."
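On the overhead point: if the checks stay, one standard option is to rely on the fact that CPython strips assert statements (and `if __debug__:` blocks) entirely when run with the -O flag, so they cost nothing in optimized runs. A small demonstration:

# Run as `python script.py`: the assert executes as usual.
# Run as `python -O script.py`: the assert is compiled away.
if __debug__:
    print("debug mode: asserts are active")
else:
    print("optimized mode (-O): asserts are stripped")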
         # Either set, or append, estimator outputs if they exist in the submitted
         # trajectory.
         if self.estimator_outputs is None and isinstance(
Review comment (on the buffer's use of `self._index`):

"But if we don't add all the training trajectories, we would have increased self._index by more than needed. Actually, do we need self._index at all in ReplayBuffers?"

Reply: "It isn't clear to me what this is used for, actually. We can chat about it at our meeting."
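For readers following the `self._index` question: in a conventional fixed-capacity FIFO replay buffer, `_index` is the circular write cursor, which stays consistent only if every generated item is actually added. A generic sketch of that convention (not torchgfn's actual implementation):

class CircularReplayBuffer:
    """Generic FIFO buffer; `_index` marks the next slot to overwrite."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.storage: list[object] = []
        self._index = 0

    def add(self, item: object) -> None:
        if len(self.storage) < self.capacity:
            self.storage.append(item)
        else:
            # Once full, overwrite the oldest entry in circular order.
            self.storage[self._index] = item
        self._index = (self._index + 1) % self.capacity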