Prioritized replay buffer #175

Merged
10 commits merged on Apr 3, 2024
2 changes: 1 addition & 1 deletion src/gfn/containers/__init__.py
@@ -1,3 +1,3 @@
from .replay_buffer import ReplayBuffer
from .replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
from .trajectories import Trajectories
from .transitions import Transitions
135 changes: 134 additions & 1 deletion src/gfn/containers/replay_buffer.py
@@ -18,7 +18,6 @@ class ReplayBuffer:

Attributes:
env: the Environment instance.
loss_fn: the Loss instance
capacity: the size of the buffer.
training_objects: the buffer of objects used for training.
terminating_states: a States class representation of $s_f$.
@@ -105,3 +104,137 @@ def load(self, directory: str):
self._index = len(self.training_objects)
if self.terminating_states is not None:
self.terminating_states.load(os.path.join(directory, "terminating_states"))


class PrioritizedReplayBuffer(ReplayBuffer):
"""A replay buffer of trajectories or transitions.

Attributes:
env: the Environment instance.
capacity: the size of the buffer.
training_objects: the buffer of objects used for training.
terminating_states: a States class representation of $s_f$.
objects_type: the type of buffer (transitions, trajectories, or states).
cutoff_distance: threshold used to determine if new last_states are different
enough from those already contained in the buffer.
p_norm_distance: p-norm distance value to pass to torch.cdist, for the
determination of novel states.
"""
def __init__(
self,
env: Env,
objects_type: Literal["transitions", "trajectories", "states"],
capacity: int = 1000,
cutoff_distance: float = 0.,
p_norm_distance: float = 1.,
):
"""Instantiates a prioritized replay buffer.
Args:
env: the Environment instance.
loss_fn: the Loss instance.
capacity: the size of the buffer.
objects_type: the type of buffer (transitions, trajectories, or states).
cutoff_distance: threshold used to determine if new last_states are
different enough from those already contained in the buffer.
p_norm_distance: p-norm distance value to pass to torch.cdist, for the
determination of novel states.
"""
super().__init__(env, objects_type, capacity)
self.cutoff_distance = cutoff_distance
self.p_norm_distance = p_norm_distance

    def _add_objs(
        self,
        training_objects: Transitions | Trajectories | tuple[States],
        terminating_states: States | None = None,
    ):
        """Adds a batch of training objects (and optional terminating states) to the buffer."""
# Adds the objects to the buffer.
self.training_objects.extend(training_objects)

# Sort elements by logreward, capping the size at the defined capacity.
ix = torch.argsort(self.training_objects.log_rewards)
self.training_objects = self.training_objects[ix]
self.training_objects = self.training_objects[-self.capacity :]

# Add the terminating states to the buffer.
if self.terminating_states is not None:
assert terminating_states is not None
self.terminating_states.extend(terminating_states)

# Sort terminating states by logreward as well.
self.terminating_states = self.terminating_states[ix]
self.terminating_states = self.terminating_states[-self.capacity :]

def add(self, training_objects: Transitions | Trajectories | tuple[States]):
"""Adds a training object to the buffer."""
terminating_states = None
if isinstance(training_objects, tuple):
assert self.objects_type == "states" and self.terminating_states is not None
training_objects, terminating_states = training_objects

to_add = len(training_objects)

self._is_full |= self._index + to_add >= self.capacity
self._index = (self._index + to_add) % self.capacity
Collaborator:
But if we don't add all the training objects, we would have increased self._index by more than needed.

Actually, do we need self._index at all in ReplayBuffers?

Collaborator (Author):
It isn't clear to me what this is actually used for. We can chat about it at our meeting.


# The buffer isn't full yet.
if len(self.training_objects) < self.capacity:
            self._add_objs(training_objects, terminating_states)

# Our buffer is full and we will prioritize diverse, high reward additions.
else:
            # Sort the incoming elements by their logrewards.
            ix = torch.argsort(training_objects.log_rewards, descending=True)
            training_objects = training_objects[ix]
            if terminating_states is not None:
                terminating_states = terminating_states[ix]

# Filter all batch logrewards lower than the smallest logreward in buffer.
min_reward_in_buffer = self.training_objects.log_rewards.min()
            idx_bigger_rewards = training_objects.log_rewards > min_reward_in_buffer
            training_objects = training_objects[idx_bigger_rewards]
            if terminating_states is not None:
                terminating_states = terminating_states[idx_bigger_rewards]

# Compute all pairwise distances between the batch and the buffer.
curr_dim = training_objects.last_states.batch_shape[0]
buffer_dim = self.training_objects.last_states.batch_shape[0]

# TODO: Concatenate input with final state for conditional GFN.
# if self.is_conditional:
# batch = torch.cat(
# [dict_curr_batch["input"], dict_curr_batch["final_state"]],
# dim=-1,
# )
# buffer = torch.cat(
# [self.storage["input"], self.storage["final_state"]],
# dim=-1,
# )
batch = training_objects.last_states.tensor.float()
buffer = self.training_objects.last_states.tensor.float()

# Filter the batch for diverse final_states with high reward.
batch_batch_dist = torch.cdist(
batch.view(curr_dim, -1).unsqueeze(0),
batch.view(curr_dim, -1).unsqueeze(0),
p=self.p_norm_distance,
).squeeze(0)

            # Mask the upper triangle (incl. the diagonal) so self-distances are ignored.
            r, w = torch.triu_indices(*batch_batch_dist.shape)
batch_batch_dist[r, w] = torch.finfo(batch_batch_dist.dtype).max
batch_batch_dist = batch_batch_dist.min(-1)[0]

# Filter the batch for diverse final_states w.r.t the buffer.
batch_buffer_dist = (
torch.cdist(
batch.view(curr_dim, -1).unsqueeze(0),
buffer.view(buffer_dim, -1).unsqueeze(0),
p=self.p_norm_distance,
)
.squeeze(0)
.min(-1)[0]
)

# Remove non-diverse examples according to the above distances.
idx_batch_batch = batch_batch_dist > self.cutoff_distance
idx_batch_buffer = batch_buffer_dist > self.cutoff_distance
idx_diverse = idx_batch_batch & idx_batch_buffer

            training_objects = training_objects[idx_diverse]
            if terminating_states is not None:
                terminating_states = terminating_states[idx_diverse]

            # If any training objects remain after filtering, add them.
            if len(training_objects):
                self._add_objs(training_objects, terminating_states)
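
As a side note, the diversity filter in PrioritizedReplayBuffer.add boils down to pairwise distances plus a cutoff. Below is a minimal, self-contained sketch of that mechanism on toy tensors; the points, the 0.5 cutoff, and the assumption that the batch is already sorted by log-reward are all made up for illustration, and this is pure PyTorch rather than library code.

import torch

# Toy "terminal states": four 2-D points, assumed already sorted by log-reward (descending),
# plus a small buffer of previously stored terminal states.
batch = torch.tensor([[0.0, 0.0], [5.0, 5.0], [0.1, 0.1], [9.0, 9.0]])
buffer = torch.tensor([[10.0, 10.0], [20.0, 20.0]])
cutoff_distance, p_norm_distance = 0.5, 1.0

# Pairwise distances within the batch; mask the upper triangle (incl. the diagonal)
# so each point is only compared to the higher-reward points that precede it.
batch_batch_dist = torch.cdist(batch.unsqueeze(0), batch.unsqueeze(0), p=p_norm_distance).squeeze(0)
r, w = torch.triu_indices(*batch_batch_dist.shape)
batch_batch_dist[r, w] = torch.finfo(batch_batch_dist.dtype).max
batch_batch_dist = batch_batch_dist.min(-1)[0]

# Distance of each batch element to its nearest neighbour in the buffer.
batch_buffer_dist = (
    torch.cdist(batch.unsqueeze(0), buffer.unsqueeze(0), p=p_norm_distance).squeeze(0).min(-1)[0]
)

# Keep only elements that are far enough from both the rest of the batch and the buffer.
idx_diverse = (batch_batch_dist > cutoff_distance) & (batch_buffer_dist > cutoff_distance)
print(idx_diverse)  # tensor([ True,  True, False,  True]) -- (0.1, 0.1) is too close to (0.0, 0.0)

In the actual buffer this filter runs only after the incoming batch has been sorted by log-reward and pruned of anything below the buffer's minimum log-reward.
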
18 changes: 15 additions & 3 deletions src/gfn/containers/trajectories.py
@@ -232,22 +232,34 @@ def extend(self, other: Trajectories) -> None:
self.states.extend(other.states)
self.when_is_done = torch.cat((self.when_is_done, other.when_is_done), dim=0)

# For log_probs, we first need to make the first dimensions of self.log_probs and other.log_probs equal
# (i.e. the number of steps in the trajectories), and then concatenate them
# For log_probs, we first need to make the first dimensions of self.log_probs
# and other.log_probs equal (i.e. the number of steps in the trajectories), and
# then concatenate them.
new_max_length = max(self.log_probs.shape[0], other.log_probs.shape[0])
self.log_probs = self.extend_log_probs(self.log_probs, new_max_length)
other.log_probs = self.extend_log_probs(other.log_probs, new_max_length)

self.log_probs = torch.cat((self.log_probs, other.log_probs), dim=1)

# Concatenate log_rewards of the trajectories.
if self._log_rewards is not None and other._log_rewards is not None:
self._log_rewards = torch.cat(
(self._log_rewards, other._log_rewards),
dim=0,
)
        # If this Trajectories object does not yet have `log_rewards` assigned but the
        # incoming trajectories do, simply adopt theirs.
elif self._log_rewards is None and other._log_rewards is not None:
Collaborator:
Why is this needed? This is actually dangerous and can easily lead to undesired behavior.

Collaborator (Author):
We had a situation in the buffer where the freshly initialized, empty trajectory had _log_rewards = None, so calls to .extend() did not update _log_rewards. We can handle this a few ways, but as is, tests are passing.

self._log_rewards = other._log_rewards
else:
self._log_rewards = None

# Ensure log_probs/rewards are the correct dimensions. TODO: Remove?
if self.log_probs.numel() > 0:
assert self.log_probs.shape == self.actions.batch_shape

if self.log_rewards is not None:
assert len(self.log_rewards) == self.actions.batch_shape[-1]

Comment on lines +257 to +263
Collaborator:
Fair.

Collaborator (Author):
Yes, good to check -- though ideally we wouldn't have to, since it adds overhead.

# Either set, or append, estimator outputs if they exist in the submitted
# trajectory.
if self.estimator_outputs is None and isinstance(
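
Returning to the log_rewards discussion above, the three-way branch in extend() reduces to the pattern sketched below with plain tensors. The helper name merge_log_rewards is hypothetical and only for illustration; it is not part of the library.

import torch

def merge_log_rewards(ours, theirs):
    # Both sides have log_rewards: concatenate along the batch dimension.
    if ours is not None and theirs is not None:
        return torch.cat((ours, theirs), dim=0)
    # We have none yet (e.g. an empty, freshly initialized Trajectories object): adopt theirs.
    elif ours is None and theirs is not None:
        return theirs
    # Otherwise there is nothing consistent to keep.
    else:
        return None

print(merge_log_rewards(None, torch.tensor([-1.2, -0.3])))  # tensor([-1.2000, -0.3000])
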
20 changes: 15 additions & 5 deletions tutorials/examples/train_hypergrid.py
@@ -17,7 +17,7 @@
import wandb
from tqdm import tqdm, trange

from gfn.containers import ReplayBuffer
from gfn.containers import ReplayBuffer, PrioritizedReplayBuffer
from gfn.gflownet import (
DBGFlowNet,
FMGFlowNet,
@@ -185,12 +185,17 @@ def main(args): # noqa: C901
objects_type = "states"
else:
raise NotImplementedError(f"Unknown loss: {args.loss}")
replay_buffer = ReplayBuffer(
env, objects_type=objects_type, capacity=args.replay_buffer_size
)

# 3. Create the optimizer
if args.replay_buffer_prioritized:
replay_buffer = PrioritizedReplayBuffer(
env, objects_type=objects_type, capacity=args.replay_buffer_size
)
else:
replay_buffer = ReplayBuffer(
env, objects_type=objects_type, capacity=args.replay_buffer_size
)

# 3. Create the optimizer
# Policy parameters have their own LR.
params = [
{
@@ -292,6 +297,11 @@
default=0,
help="If zero, no replay buffer is used. Otherwise, the replay buffer is used.",
)
parser.add_argument(
"--replay_buffer_prioritized",
action="store_true",
help="If set and replay_buffer_size > 0, use a prioritized replay buffer.",
)
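
For reference, passing --replay_buffer_size 1000 --replay_buffer_prioritized to train_hypergrid.py selects the new buffer. A hypothetical standalone sketch of the equivalent construction is below; the HyperGrid environment and its constructor arguments are assumptions about the library and are not part of this diff.

from gfn.containers import PrioritizedReplayBuffer, ReplayBuffer
from gfn.gym import HyperGrid  # assumed environment, not shown in this PR

env = HyperGrid(ndim=2, height=8)
use_prioritized = True  # i.e. --replay_buffer_prioritized was passed
buffer_cls = PrioritizedReplayBuffer if use_prioritized else ReplayBuffer
replay_buffer = buffer_cls(env, objects_type="trajectories", capacity=1000)
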

parser.add_argument(
"--loss",