Prioritized replay buffer #175

Merged: 10 commits merged from prioritized_replay_buffer into master on Apr 3, 2024

Conversation

josephdviviano (Collaborator)

I've added a prioritized replay buffer. This:

  • Only adds examples if their reward is larger than the min reward found in the buffer.
  • Only adds examples if they are unique (by default).

In general, uniqueness is defined in terms of the p-norm distances between the candidate batch and the states already in the buffer -- both the norm and the threshold are configurable by the user. The default settings use a p-norm of 1 and a distance threshold of 0, i.e., an added state should not be identical to any state already in the buffer.
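
For illustration, a minimal sketch of the two filters described above (the tensor names and the free-function form are illustrative only, not the actual buffer API):

    import torch

    def filter_candidates(
        buffer_states: torch.Tensor,      # (B, D) states already in the buffer
        buffer_rewards: torch.Tensor,     # (B,)   their log-rewards
        candidate_states: torch.Tensor,   # (N, D) incoming terminal states
        candidate_rewards: torch.Tensor,  # (N,)   their log-rewards
        p_norm: float = 1.0,
        cutoff_distance: float = 0.0,
    ):
        # (1) Priority: a candidate must beat the worst reward in the buffer.
        idx_priority = candidate_rewards > buffer_rewards.min()

        # (2) Uniqueness: the p-norm distance from each candidate to every stored
        #     state must exceed the cutoff (0 by default, i.e. "not identical").
        dists = torch.cdist(candidate_states, buffer_states, p=p_norm)  # (N, B)
        idx_diverse = (dists > cutoff_distance).all(dim=-1)

        keep = idx_priority & idx_diverse
        return candidate_states[keep], candidate_rewards[keep]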

Important: you can test this using

tutorials/examples/train_hypergrid.py --replay_buffer_size 1000 --replay_buffer_prioritized

Note: currently, the standard buffer outperforms the prioritized buffer using these default settings!

  • Standard: 'loss': 9.107343066716567e-05, 'states_visited': 998416, 'l1_dist': 0.00023296871222555637, 'logZ_diff': 0.001130819320678711
  • Prioritized: 'loss': 0.0003514138516038656, 'states_visited': 998416, 'l1_dist': 0.00017267849761992693, 'logZ_diff': 0.0020639896392822266

Using the debugger, I determined that no samples were ever added to the buffer after it was initially filled, because the states were never found to be unique. I.e., in replay_buffer.py, the following logic always produced an idx_batch_buffer that was entirely False:

    # Remove non-diverse examples according to the above distances.
    # Candidates must be farther than the cutoff from the other candidates in the batch...
    idx_batch_batch = batch_batch_dist > self.cutoff_distance
    # ...and farther than the cutoff from everything already stored in the buffer.
    idx_batch_buffer = batch_buffer_dist > self.cutoff_distance
    # A candidate is kept only if both conditions hold.
    idx_diverse = idx_batch_batch & idx_batch_buffer

@saleml I'd be curious to get your opinion on this. Perhaps we can tweak the implementation of the prioritized replay buffer, or perhaps this should be expected behaviour for this relatively simple example. I am not sure.
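
For context, a toy illustration of how the mask can end up all False with the default cutoff of 0: once every reachable terminal state is already stored, each candidate's nearest buffer state is at distance 0. (This is a simplified example, not the exact tensors or reduction used in replay_buffer.py.)

    import torch

    # Toy 2x2 hypergrid: only four distinct terminal states, all already stored.
    buffer_states = torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
    batch_states = torch.tensor([[1.0, 0.0], [1.0, 1.0]])  # new batch revisits them

    # Distance from each candidate to its nearest buffer state (p-norm = 1).
    batch_buffer_dist = torch.cdist(batch_states, buffer_states, p=1).min(dim=-1).values
    idx_batch_buffer = batch_buffer_dist > 0.0  # cutoff_distance = 0

    print(idx_batch_buffer)  # tensor([False, False]) -> nothing is ever added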

@josephdviviano requested a review from saleml March 30, 2024 23:06
@josephdviviano self-assigned this Mar 30, 2024
    if self._log_rewards is not None and other._log_rewards is not None:
        self._log_rewards = torch.cat(
            (self._log_rewards, other._log_rewards),
            dim=0,
        )
    # If the trajectories object does not yet have `log_rewards` assigned but the
    # external trajectory has log_rewards, simply assign them over.
    elif self._log_rewards is None and other._log_rewards is not None:
        self._log_rewards = other._log_rewards
Collaborator

why is this needed?
This is actually dangerous and can easily lead to undesired behavior.

Collaborator Author

We had a situation in the buffer where an empty, freshly initialized trajectories object had _log_rewards = None, so calls to .extend() never updated _log_rewards. We could handle this a few different ways, but as is, thankfully tests are passing.

Comment on lines +256 to +262
    # Ensure log_probs/rewards are the correct dimensions. TODO: Remove?
    if self.log_probs.numel() > 0:
        assert self.log_probs.shape == self.actions.batch_shape

    if self.log_rewards is not None:
        assert len(self.log_rewards) == self.actions.batch_shape[-1]

Collaborator

fair

Collaborator Author

Yes, it's good to check -- but ideally we wouldn't have to, since this adds overhead.
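
For what it's worth, one way to keep the check without paying for it outside of development runs (a sketch of a drop-in variant of the snippet above, assuming the same surrounding method context):

    # A sketch, not the merged code: plain `assert`s are stripped entirely when
    # Python runs with -O, and the `__debug__` guard also skips the `numel()` call
    # in optimized runs, so the checks cost nothing in production.
    if __debug__:
        if self.log_probs.numel() > 0:
            assert self.log_probs.shape == self.actions.batch_shape
        if self.log_rewards is not None:
            assert len(self.log_rewards) == self.actions.batch_shape[-1]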

    to_add = len(training_objects)

    # Mark the buffer as full once the write pointer would pass capacity,
    # then advance the circular write index modulo the capacity.
    self._is_full |= self._index + to_add >= self.capacity
    self._index = (self._index + to_add) % self.capacity
Collaborator

But if we don't add all the training objects, we will have increased self._index by more than needed.

Actually, do we need self._index at all in ReplayBuffers?

Collaborator Author

It isn't clear to me what this is actually used for. We can chat about it at our meeting.
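
For reference, one way the concern above could be addressed (a hypothetical sketch, not the merged code; `added_objects` is a placeholder name, not an identifier from the PR):

    # Advance the circular write index only by the number of training objects
    # actually kept after the priority/diversity filtering, so a partially
    # rejected batch does not skip buffer slots.
    to_add = len(added_objects)  # the post-filtering batch, not the raw batch

    self._is_full |= self._index + to_add >= self.capacity
    self._index = (self._index + to_add) % self.capacity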

Base automatically changed from fix_off_policy to master April 2, 2024 14:57
@josephdviviano (Collaborator Author)

@saleml Just need your official sign off on this before I can merge :)

@saleml left a comment (Collaborator)

LGTM! Great PR

@saleml merged commit 4387e5b into master Apr 3, 2024
3 checks passed
@josephdviviano deleted the prioritized_replay_buffer branch April 5, 2024 18:12
@josephdviviano restored the prioritized_replay_buffer branch November 15, 2024 00:49
@josephdviviano deleted the prioritized_replay_buffer branch November 15, 2024 00:49
@josephdviviano restored the prioritized_replay_buffer branch November 15, 2024 00:51
@josephdviviano deleted the prioritized_replay_buffer branch November 15, 2024 00:51