Prioritized replay buffer #175
Conversation
…ory contains log_rewards and the internal trajectory is None (this can happen with an empty-initialized trajectory)
src/gfn/containers/trajectories.py
Outdated
if self._log_rewards is not None and other._log_rewards is not None:
    self._log_rewards = torch.cat(
        (self._log_rewards, other._log_rewards),
        dim=0,
    )
# If the trajectories object does not yet have `log_rewards` assigned but the
# external trajectory has log_rewards, simply assign them over.
elif self._log_rewards is None and other._log_rewards is not None:
why is this needed?
This is actually dangerous and can easily lead to undesired behavior.
We had a situation in the buffer where the empty-initialized trajectory had `_log_rewards = None`, so any call to `.extend()` did not update `_log_rewards` -- we can handle this a few ways but, as is, thankfully tests are passing.
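For context, here is a minimal, self-contained sketch of the case analysis under discussion. It uses a hypothetical `LogRewardHolder` stand-in rather than the real `Trajectories` container, and the third branch is just one possible way to guard against the silent mixing of rewarded and unrewarded trajectories that the reviewer is worried about:

```python
import torch
from typing import Optional


class LogRewardHolder:
    """Hypothetical stand-in for a trajectories container; illustration only."""

    def __init__(self, log_rewards: Optional[torch.Tensor] = None):
        self._log_rewards = log_rewards

    def extend(self, other: "LogRewardHolder") -> None:
        if self._log_rewards is not None and other._log_rewards is not None:
            # Both sides carry log-rewards: concatenate along the batch dim.
            self._log_rewards = torch.cat(
                (self._log_rewards, other._log_rewards), dim=0
            )
        elif self._log_rewards is None and other._log_rewards is not None:
            # Empty-initialized container: adopt the incoming log-rewards
            # instead of silently dropping them.
            self._log_rewards = other._log_rewards.clone()
        elif self._log_rewards is not None and other._log_rewards is None:
            # Extending rewarded trajectories with unrewarded ones would
            # misalign rewards and trajectories, so fail loudly.
            raise ValueError(
                "Cannot extend log-rewarded trajectories with unrewarded ones."
            )
        # If both are None, there is nothing to do.


buffer_trajs = LogRewardHolder()                      # empty-initialized, as in the buffer
buffer_trajs.extend(LogRewardHolder(torch.zeros(4)))  # log-rewards are adopted, not lost
buffer_trajs.extend(LogRewardHolder(torch.ones(2)))   # subsequent calls concatenate
assert buffer_trajs._log_rewards.shape == (6,)
```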
# Ensure log_probs/rewards are the correct dimensions. TODO: Remove?
if self.log_probs.numel() > 0:
    assert self.log_probs.shape == self.actions.batch_shape

if self.log_rewards is not None:
    assert len(self.log_rewards) == self.actions.batch_shape[-1]
fair
yes, good to check -- but ideally we wouldn't have to check -- this adds overhead.
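As a side note on the overhead concern (a general Python mechanism, not something this PR adds): `assert` statements are removed entirely when the interpreter runs in optimized mode, so sanity checks like the ones above can stay in the code while costing nothing in `python -O` runs. A tiny standalone illustration:

```python
import torch

log_rewards = torch.zeros(8)
batch_shape = (3, 8)

# `__debug__` is True in normal runs and False under `python -O`, where
# `assert` statements are stripped, so this check is free in optimized mode.
if __debug__:
    assert len(log_rewards) == batch_shape[-1]

print("checks enabled:", __debug__)
```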
src/gfn/containers/replay_buffer.py
Outdated
to_add = len(training_objects)

self._is_full |= self._index + to_add >= self.capacity
self._index = (self._index + to_add) % self.capacity
But if we don't add all the training objects, we would have increased `self._index` by more than needed. Actually, do we need `self._index` at all in ReplayBuffers?
It isn't clear to me what this is actually used for. We can chat about it at our meeting.
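For reference, here is a minimal sketch of the circular-index bookkeeping being discussed, using a plain Python list instead of the library's trajectory containers; it only illustrates how `_index` and `_is_full` evolve when an insert wraps past `capacity`:

```python
class TinyRingBuffer:
    """Illustrative ring buffer; not the library's ReplayBuffer."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self._index = 0          # next write position (modulo capacity)
        self._is_full = False
        self._storage = []

    def add(self, items):
        to_add = len(items)
        # Mark the buffer full once the write head would pass the end.
        self._is_full |= self._index + to_add >= self.capacity
        self._index = (self._index + to_add) % self.capacity

        # Store the items, dropping the oldest ones past capacity.
        self._storage.extend(items)
        self._storage = self._storage[-self.capacity:]


buf = TinyRingBuffer(capacity=5)
buf.add([1, 2, 3])
buf.add([4, 5, 6, 7])  # wraps: _is_full becomes True, _index becomes 2
assert buf._is_full and buf._index == 2 and len(buf._storage) == 5
```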
…extend always works, which means we no longer need the extra condition (but we are leaving in the sanity check for now).
@saleml Just need your official sign-off on this before I can merge :)
LGTM! Great PR
I've added a prioritized replay buffer.
In general, uniqueness is defined via distances between the candidate batch and the buffer states under a `p_norm` -- this is configurable by the user. The default settings use a p-norm of 1 and a distance threshold of 0, i.e., the added states should not be identical to any state already in the buffer.

Important: you can test this using
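To make the uniqueness rule concrete, here is a minimal sketch of such a p-norm diversity check using `torch.cdist` (the tensor shapes and the helper name are illustrative; the actual buffer works on the library's States containers and its masking details may differ):

```python
import torch


def unique_enough(candidates: torch.Tensor,
                  buffer_states: torch.Tensor,
                  p_norm: float = 1.0,
                  cutoff: float = 0.0) -> torch.Tensor:
    """Mask over `candidates` that are farther than `cutoff` (in p-norm)
    from every state already in the buffer."""
    # Pairwise distances: shape (n_candidates, n_buffer).
    dists = torch.cdist(candidates.float(), buffer_states.float(), p=p_norm)
    # A candidate is "unique enough" if even its nearest buffer state is
    # strictly farther than the cutoff. With cutoff=0 this only rejects
    # exact duplicates of stored states.
    return dists.min(dim=1).values > cutoff


buffer_states = torch.tensor([[0., 0.], [1., 2.]])
candidates = torch.tensor([[0., 0.], [3., 3.]])
mask = unique_enough(candidates, buffer_states)  # tensor([False, True])
to_add = candidates[mask]                        # only the novel state survives
```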
Note: currently, the standard buffer outperforms the prioritized buffer using these default settings!
{'loss': 9.107343066716567e-05, 'states_visited': 998416, 'l1_dist': 0.00023296871222555637, 'logZ_diff': 0.001130819320678711}
{'loss': 0.0003514138516038656, 'states_visited': 998416, 'l1_dist': 0.00017267849761992693, 'logZ_diff': 0.0020639896392822266}
In the debugger, I could determine that no samples were ever added to the buffer after it was originally filled, because the states were not found to be unique. I.e., in `replay_buffer.py`, the following logic always had `idx_batch_buffer` completely full of `False`:

@saleml I'd be curious to get your opinion on this. Perhaps we can tweak the implementation of the prioritized replay buffer, or perhaps this should be expected behaviour for this relatively simple example. I am not sure.