added a prioritized replay buffer
josephdviviano committed Mar 30, 2024
1 parent 61f7fd2 commit 75e3198
Showing 2 changed files with 135 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/gfn/containers/__init__.py
@@ -1,3 +1,3 @@
-from .replay_buffer import ReplayBuffer
+from .replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
 from .trajectories import Trajectories
 from .transitions import Transitions
135 changes: 134 additions & 1 deletion src/gfn/containers/replay_buffer.py
@@ -18,7 +18,6 @@ class ReplayBuffer:
Attributes:
env: the Environment instance.
loss_fn: the Loss instance
capacity: the size of the buffer.
training_objects: the buffer of objects used for training.
terminating_states: a States class representation of $s_f$.
@@ -105,3 +104,137 @@ def load(self, directory: str):
self._index = len(self.training_objects)
if self.terminating_states is not None:
self.terminating_states.load(os.path.join(directory, "terminating_states"))


class PrioritizedReplayBuffer(ReplayBuffer):
"""A replay buffer of trajectories or transitions.
Attributes:
env: the Environment instance.
capacity: the size of the buffer.
training_objects: the buffer of objects used for training.
terminating_states: a States class representation of $s_f$.
objects_type: the type of buffer (transitions, trajectories, or states).
cutoff_distance: threshold used to determine if new last_states are different
enough from those already contained in the buffer.
p_norm_distance: p-norm distance value to pass to torch.cdist, for the
determination of novel states.
"""
def __init__(
self,
env: Env,
objects_type: Literal["transitions", "trajectories", "states"],
capacity: int = 1000,
cutoff_distance: float = 0.,
p_norm_distance: float = 1.,
):
"""Instantiates a prioritized replay buffer.
Args:
env: the Environment instance.
            objects_type: the type of buffer (transitions, trajectories, or states).
            capacity: the size of the buffer.
cutoff_distance: threshold used to determine if new last_states are
different enough from those already contained in the buffer.
p_norm_distance: p-norm distance value to pass to torch.cdist, for the
determination of novel states.
"""
super().__init__(env, objects_type, capacity)
self.cutoff_distance = cutoff_distance
self.p_norm_distance = p_norm_distance

    def _add_objs(
        self,
        training_objects: Transitions | Trajectories | tuple[States],
        terminating_states: States | None = None,
    ):
        """Adds training objects (and optional terminating states) to the buffer."""
# Adds the objects to the buffer.
self.training_objects.extend(training_objects)

# Sort elements by logreward, capping the size at the defined capacity.
ix = torch.argsort(self.training_objects.log_rewards)
self.training_objects = self.training_objects[ix]
self.training_objects = self.training_objects[-self.capacity :]

# Add the terminating states to the buffer.
if self.terminating_states is not None:
assert terminating_states is not None
self.terminating_states.extend(terminating_states)

# Sort terminating states by logreward as well.
self.terminating_states = self.terminating_states[ix]
self.terminating_states = self.terminating_states[-self.capacity :]

def add(self, training_objects: Transitions | Trajectories | tuple[States]):
"""Adds a training object to the buffer."""
terminating_states = None
if isinstance(training_objects, tuple):
assert self.objects_type == "states" and self.terminating_states is not None
training_objects, terminating_states = training_objects

to_add = len(training_objects)

self._is_full |= self._index + to_add >= self.capacity
self._index = (self._index + to_add) % self.capacity

# The buffer isn't full yet.
if len(self.training_objects) < self.capacity:
            self._add_objs(training_objects, terminating_states)

        # The buffer is full, so prioritize diverse, high-reward additions.
else:
# Sort the incoming elements by their logrewards.
            ix = torch.argsort(training_objects.log_rewards, descending=True)
training_objects = training_objects[ix]

            # Keep only batch elements whose logreward exceeds the smallest logreward in the buffer.
min_reward_in_buffer = self.training_objects.log_rewards.min()
idx_bigger_rewards = training_objects.log_rewards > min_reward_in_buffer
training_objects = training_objects[idx_bigger_rewards]

# Compute all pairwise distances between the batch and the buffer.
curr_dim = training_objects.last_states.batch_shape[0]
buffer_dim = self.training_objects.last_states.batch_shape[0]

# TODO: Concatenate input with final state for conditional GFN.
# if self.is_conditional:
# batch = torch.cat(
# [dict_curr_batch["input"], dict_curr_batch["final_state"]],
# dim=-1,
# )
# buffer = torch.cat(
# [self.storage["input"], self.storage["final_state"]],
# dim=-1,
# )
batch = training_objects.last_states.tensor.float()
buffer = self.training_objects.last_states.tensor.float()

# Filter the batch for diverse final_states with high reward.
batch_batch_dist = torch.cdist(
batch.view(curr_dim, -1).unsqueeze(0),
batch.view(curr_dim, -1).unsqueeze(0),
p=self.p_norm_distance,
).squeeze(0)

            # Mask the upper triangle (incl. self-distances on the diagonal) so each
            # candidate is only compared against earlier candidates in the batch.
            r, w = torch.triu_indices(*batch_batch_dist.shape)
            batch_batch_dist[r, w] = torch.finfo(batch_batch_dist.dtype).max
            batch_batch_dist = batch_batch_dist.min(-1)[0]

# Filter the batch for diverse final_states w.r.t the buffer.
batch_buffer_dist = (
torch.cdist(
batch.view(curr_dim, -1).unsqueeze(0),
buffer.view(buffer_dim, -1).unsqueeze(0),
p=self.p_norm_distance,
)
.squeeze(0)
.min(-1)[0]
)

# Remove non-diverse examples according to the above distances.
idx_batch_batch = batch_batch_dist > self.cutoff_distance
idx_batch_buffer = batch_buffer_dist > self.cutoff_distance
idx_diverse = idx_batch_batch & idx_batch_buffer

training_objects = training_objects[idx_diverse]

            # If any training objects remain after filtering, add them.
if len(training_objects):
self._add_objs(training_objects)
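
For reference, the diversity filter used in the full-buffer branch above can be exercised in isolation with plain PyTorch. The snippet below is a minimal sketch (not part of the commit) using toy 2-D final states and made-up cutoff values; it mirrors the cdist / triu_indices masking logic from the diff.

import torch

# Toy stand-ins for the flattened last_states of the incoming batch and the buffer.
batch = torch.tensor([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]])  # 3 candidate states
buffer = torch.tensor([[3.0, 3.0]])                          # 1 state already stored
cutoff_distance, p_norm_distance = 1.0, 1.0

# Pairwise L1 distances among candidates; mask the upper triangle (incl. the
# diagonal) so each candidate is only compared against earlier candidates.
batch_batch_dist = torch.cdist(
    batch.unsqueeze(0), batch.unsqueeze(0), p=p_norm_distance
).squeeze(0)
r, w = torch.triu_indices(*batch_batch_dist.shape)
batch_batch_dist[r, w] = torch.finfo(batch_batch_dist.dtype).max
batch_batch_dist = batch_batch_dist.min(-1)[0]

# Distance from each candidate to its nearest neighbour already in the buffer.
batch_buffer_dist = (
    torch.cdist(batch.unsqueeze(0), buffer.unsqueeze(0), p=p_norm_distance)
    .squeeze(0)
    .min(-1)[0]
)

# Keep candidates that are far enough from both the batch and the buffer.
idx_diverse = (batch_batch_dist > cutoff_distance) & (batch_buffer_dist > cutoff_distance)
print(idx_diverse)  # tensor([True, False, True]): the 2nd candidate is too close to the 1st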

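As a usage note (also not part of the diff): a minimal sketch of how the new buffer might be wired into a training loop, assuming an existing gfn Env instance `env` and a Trajectories batch `trajectories` produced by whatever sampler the user already has. Only the constructor and add() signatures introduced in this commit are relied on; everything else is a placeholder.

from gfn.containers import PrioritizedReplayBuffer

# `env` and `trajectories` are placeholders for the user's existing environment
# and sampled trajectories; they are not defined in this commit.
replay_buffer = PrioritizedReplayBuffer(
    env,
    objects_type="trajectories",
    capacity=1000,
    cutoff_distance=0.1,  # new last_states must be at least this far from stored ones
    p_norm_distance=1.0,  # L1 distance passed to torch.cdist
)

# While the buffer is below capacity, add() simply extends and re-sorts by logreward;
# once full, it keeps only candidates that beat the buffer's minimum logreward and
# pass the diversity filter above.
replay_buffer.add(trajectories)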