Prioritized replay buffer #175

Merged: 10 commits, Apr 3, 2024
2 changes: 1 addition & 1 deletion src/gfn/containers/__init__.py
@@ -1,3 +1,3 @@
from .replay_buffer import ReplayBuffer
from .replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
from .trajectories import Trajectories
from .transitions import Transitions
143 changes: 137 additions & 6 deletions src/gfn/containers/replay_buffer.py
@@ -18,7 +18,6 @@ class ReplayBuffer:
Attributes:
env: the Environment instance.
loss_fn: the Loss instance
capacity: the size of the buffer.
training_objects: the buffer of objects used for training.
terminating_states: a States class representation of $s_f$.
@@ -56,13 +55,12 @@ def __init__(
raise ValueError(f"Unknown objects_type: {objects_type}")

self._is_full = False
self._index = 0

def __repr__(self):
return f"ReplayBuffer(capacity={self.capacity}, containing {len(self)} {self.objects_type})"

def __len__(self):
return self.capacity if self._is_full else self._index
return len(self.training_objects)

def add(self, training_objects: Transitions | Trajectories | tuple[States]):
"""Adds a training object to the buffer."""
@@ -73,8 +71,7 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]):

to_add = len(training_objects)

self._is_full |= self._index + to_add >= self.capacity
self._index = (self._index + to_add) % self.capacity
self._is_full |= len(self) + to_add >= self.capacity

self.training_objects.extend(training_objects)
self.training_objects = self.training_objects[-self.capacity :]
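With this change the buffer behaves as a simple FIFO once full: new objects are appended and the container is truncated to the last `capacity` elements, replacing the old index arithmetic. A minimal sketch of the same idea on a plain Python list (the values are illustrative):

    capacity = 3
    buffer = [1, 2]
    incoming = [3, 4, 5]

    buffer.extend(incoming)
    buffer = buffer[-capacity:]
    print(buffer)  # [3, 4, 5] -- the oldest entries are dropped once capacity is exceeded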
@@ -102,6 +99,140 @@ def save(self, directory: str):
def load(self, directory: str):
"""Loads the buffer from disk."""
self.training_objects.load(os.path.join(directory, "training_objects"))
self._index = len(self.training_objects)
if self.terminating_states is not None:
self.terminating_states.load(os.path.join(directory, "terminating_states"))


class PrioritizedReplayBuffer(ReplayBuffer):
"""A replay buffer of trajectories or transitions.
Attributes:
env: the Environment instance.
capacity: the size of the buffer.
training_objects: the buffer of objects used for training.
terminating_states: a States class representation of $s_f$.
objects_type: the type of buffer (transitions, trajectories, or states).
cutoff_distance: threshold used to determine if new last_states are different
enough from those already contained in the buffer.
p_norm_distance: p-norm distance value to pass to torch.cdist, for the
determination of novel states.
"""
def __init__(
self,
env: Env,
objects_type: Literal["transitions", "trajectories", "states"],
capacity: int = 1000,
cutoff_distance: float = 0.,
p_norm_distance: float = 1.,
):
"""Instantiates a prioritized replay buffer.
Args:
env: the Environment instance.
capacity: the size of the buffer.
objects_type: the type of buffer (transitions, trajectories, or states).
cutoff_distance: threshold used to determine if new last_states are
different enough from those already contained in the buffer. If the
cutoff is negative, all diversity calculations are skipped (since all
norms are >= 0).
p_norm_distance: p-norm distance value to pass to torch.cdist, for the
determination of novel states.
"""
super().__init__(env, objects_type, capacity)
self.cutoff_distance = cutoff_distance
self.p_norm_distance = p_norm_distance

def _add_objs(self, training_objects: Transitions | Trajectories | tuple[States], terminating_states: States | None = None):
"""Adds a training object (and optional terminating states) to the buffer."""
# Adds the objects to the buffer.
self.training_objects.extend(training_objects)

# Sort elements by logreward, capping the size at the defined capacity.
ix = torch.argsort(self.training_objects.log_rewards)
self.training_objects = self.training_objects[ix]
self.training_objects = self.training_objects[-self.capacity :]

# Add the terminating states to the buffer.
if self.terminating_states is not None:
assert terminating_states is not None
self.terminating_states.extend(terminating_states)

# Sort terminating states by logreward as well.
self.terminating_states = self.terminating_states[ix]
self.terminating_states = self.terminating_states[-self.capacity :]

def add(self, training_objects: Transitions | Trajectories | tuple[States]):
"""Adds a training object to the buffer."""
terminating_states = None
if isinstance(training_objects, tuple):
assert self.objects_type == "states" and self.terminating_states is not None
training_objects, terminating_states = training_objects

to_add = len(training_objects)
self._is_full |= len(self) + to_add >= self.capacity

# The buffer isn't full yet.
if len(self.training_objects) < self.capacity:
self._add_objs(training_objects, terminating_states)

# Our buffer is full and we will prioritize diverse, high reward additions.
else:
# Sort the incoming elements by their logrewards.
ix = torch.argsort(training_objects.log_rewards, descending=True)
training_objects = training_objects[ix]

# Filter all batch logrewards lower than the smallest logreward in buffer.
min_reward_in_buffer = self.training_objects.log_rewards.min()
idx_bigger_rewards = training_objects.log_rewards >= min_reward_in_buffer
training_objects = training_objects[idx_bigger_rewards]

# TODO: Concatenate input with final state for conditional GFN.
# if self.is_conditional:
# batch = torch.cat(
# [dict_curr_batch["input"], dict_curr_batch["final_state"]],
# dim=-1,
# )
# buffer = torch.cat(
# [self.storage["input"], self.storage["final_state"]],
# dim=-1,
# )

if self.cutoff_distance >= 0:
# Filter the batch for diverse final_states with high reward.
batch = training_objects.last_states.tensor.float()
batch_dim = training_objects.last_states.batch_shape[0]
batch_batch_dist = torch.cdist(
batch.view(batch_dim, -1).unsqueeze(0),
batch.view(batch_dim, -1).unsqueeze(0),
p=self.p_norm_distance,
).squeeze(0)

# Find the min distance in each row and drop rows below the cutoff.
# The upper triangle (incl. the diagonal) is masked out so each element is
# only compared against the higher-reward elements sorted before it.
r, w = torch.triu_indices(*batch_batch_dist.shape)
batch_batch_dist[r, w] = torch.finfo(batch_batch_dist.dtype).max
batch_batch_dist = batch_batch_dist.min(-1)[0]
idx_batch_batch = batch_batch_dist > self.cutoff_distance
training_objects = training_objects[idx_batch_batch]

# Compute all pairwise distances between the remaining batch & buffer.
batch = training_objects.last_states.tensor.float()
buffer = self.training_objects.last_states.tensor.float()
batch_dim = training_objects.last_states.batch_shape[0]
buffer_dim = self.training_objects.last_states.batch_shape[0]
batch_buffer_dist = (
torch.cdist(
batch.view(batch_dim, -1).unsqueeze(0),
buffer.view(buffer_dim, -1).unsqueeze(0),
p=self.p_norm_distance,
)
.squeeze(0)
.min(-1)[0] # Min calculated over rows - the batch elements.
)

# Filter the batch for diverse final_states w.r.t the buffer.
idx_batch_buffer = batch_buffer_dist > self.cutoff_distance
training_objects = training_objects[idx_batch_buffer]

# If any training objects remain after filtering, add them.
if len(training_objects):
self._add_objs(training_objects)
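Two pieces of the class above are worth seeing in isolation. Prioritization is simply "sort by log-reward, keep the top `capacity` elements", and the diversity check is a `torch.cdist` pass whose row-wise minimum distance is compared against `cutoff_distance`. Below is a standalone sketch of both on plain tensors; `diversity_filter` and the toy data are illustrative, not part of this PR:

    import torch

    # Prioritization: sort ascending by log-reward, keep the `capacity` best.
    capacity = 3
    log_rewards = torch.tensor([-2.0, -0.5, -3.0, -1.0])
    ix = torch.argsort(log_rewards)      # ascending order of log-rewards
    kept = log_rewards[ix][-capacity:]   # the three highest log-rewards
    print(kept)                          # tensor([-2.0000, -1.0000, -0.5000])

    # Diversity: keep only candidates that are far (in p-norm) from every
    # state already in the buffer.
    def diversity_filter(candidates, buffered, cutoff_distance, p=1.0):
        """Boolean mask over `candidates` whose nearest buffered neighbour is
        farther than `cutoff_distance` under the p-norm."""
        dists = torch.cdist(
            candidates.unsqueeze(0), buffered.unsqueeze(0), p=p
        ).squeeze(0)                     # shape: (n_candidates, n_buffered)
        return dists.min(dim=-1).values > cutoff_distance

    candidates = torch.tensor([[0.0, 0.0], [5.0, 5.0], [0.0, 1.0], [9.0, 9.0]])
    buffered = torch.tensor([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
    print(diversity_filter(candidates, buffered, cutoff_distance=2.0))
    # tensor([False,  True, False,  True]) -- only sufficiently novel states survive

The within-batch filter in `add` uses the same row-wise minimum, but masks the upper triangle of the self-distance matrix (including the diagonal) so each candidate is only compared to the higher-reward candidates sorted before it.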
21 changes: 17 additions & 4 deletions src/gfn/containers/trajectories.py
@@ -90,7 +90,11 @@ def __init__(
if when_is_done is not None
else torch.full(size=(0,), fill_value=-1, dtype=torch.long)
)
self._log_rewards = log_rewards
self._log_rewards = (
log_rewards
if log_rewards is not None
else torch.full(size=(0,), fill_value=0, dtype=torch.float)
)
self.log_probs = (
log_probs
if log_probs is not None
@@ -232,22 +236,31 @@ def extend(self, other: Trajectories) -> None:
self.states.extend(other.states)
self.when_is_done = torch.cat((self.when_is_done, other.when_is_done), dim=0)

# For log_probs, we first need to make the first dimensions of self.log_probs and other.log_probs equal
# (i.e. the number of steps in the trajectories), and then concatenate them
# For log_probs, we first need to make the first dimensions of self.log_probs
# and other.log_probs equal (i.e. the number of steps in the trajectories), and
# then concatenate them.
new_max_length = max(self.log_probs.shape[0], other.log_probs.shape[0])
self.log_probs = self.extend_log_probs(self.log_probs, new_max_length)
other.log_probs = self.extend_log_probs(other.log_probs, new_max_length)

self.log_probs = torch.cat((self.log_probs, other.log_probs), dim=1)

# Concatenate log_rewards of the trajectories.
if self._log_rewards is not None and other._log_rewards is not None:
self._log_rewards = torch.cat(
(self._log_rewards, other._log_rewards),
dim=0,
)
# _log_rewards defaults to an empty tensor when a container is initialized
# empty, so this branch is only reached when log_rewards were explicitly
# set to None.
else:
self._log_rewards = None

# Ensure log_probs/rewards are the correct dimensions. TODO: Remove?
if self.log_probs.numel() > 0:
assert self.log_probs.shape == self.actions.batch_shape

if self.log_rewards is not None:
assert len(self.log_rewards) == self.actions.batch_shape[-1]

Comment on lines +257 to +263

Collaborator: fair

Collaborator (author): yes, good to check -- but maybe ideal that we don't have to check -- this adds overhead.
# Either set, or append, estimator outputs if they exist in the submitted
# trajectory.
if self.estimator_outputs is None and isinstance(
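For context on the `extend` logic above: `log_probs` is stored as a `(max_length, n_trajectories)` tensor, so the two containers are first padded along the step dimension to a common `new_max_length` and then concatenated along the trajectory dimension. The helper below is a plausible version of what `extend_log_probs` does; the zero padding value is an assumption, not taken from this hunk:

    import torch

    def pad_log_probs(log_probs: torch.Tensor, new_max_length: int) -> torch.Tensor:
        """Pads a (max_length, n_trajectories) tensor with zeros along the step
        dimension so two containers can be concatenated on dim=1."""
        current_length, n_trajectories = log_probs.shape
        if current_length >= new_max_length:
            return log_probs
        padding = torch.zeros(new_max_length - current_length, n_trajectories)
        return torch.cat((log_probs, padding), dim=0)

    a = torch.randn(3, 2)  # 3 steps, 2 trajectories
    b = torch.randn(5, 2)  # 5 steps, 2 trajectories
    new_max_length = max(a.shape[0], b.shape[0])
    merged = torch.cat(
        (pad_log_probs(a, new_max_length), pad_log_probs(b, new_max_length)), dim=1
    )
    print(merged.shape)  # torch.Size([5, 4])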
5 changes: 4 additions & 1 deletion src/gfn/containers/transitions.py
@@ -90,7 +90,7 @@ def __init__(
len(self.next_states.batch_shape) == 1
and self.states.batch_shape == self.next_states.batch_shape
)
self._log_rewards = log_rewards
self._log_rewards = log_rewards if log_rewards is not None else torch.zeros(0)
self.log_probs = log_probs if log_probs is not None else torch.zeros(0)

@property
@@ -208,10 +208,13 @@ def extend(self, other: Transitions) -> None:
self.actions.extend(other.actions)
self.is_done = torch.cat((self.is_done, other.is_done), dim=0)
self.next_states.extend(other.next_states)

# Concatenate log_rewards of the transitions.
if self._log_rewards is not None and other._log_rewards is not None:
self._log_rewards = torch.cat(
(self._log_rewards, other._log_rewards), dim=0
)
# _log_rewards defaults to an empty tensor when a container is initialized
# empty, so this branch is only reached when log_rewards were explicitly
# set to None.
else:
self._log_rewards = None
self.log_probs = torch.cat((self.log_probs, other.log_probs), dim=0)
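The same pattern as in `Trajectories.extend` applies here: an empty container now starts with a zero-length `log_rewards` tensor instead of `None`, so `torch.cat` works without special-casing, and the `None` fallback is only hit when log-rewards were explicitly dropped. A quick illustration:

    import torch

    empty = torch.zeros(0)                 # log_rewards of a freshly created, empty container
    incoming = torch.tensor([-1.2, -0.3])  # log_rewards of the container being merged in
    print(torch.cat((empty, incoming), dim=0))
    # tensor([-1.2000, -0.3000]) -- concatenating with an empty tensor is a no-op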
24 changes: 19 additions & 5 deletions tutorials/examples/train_hypergrid.py
@@ -17,7 +17,7 @@
import wandb
from tqdm import tqdm, trange

from gfn.containers import ReplayBuffer
from gfn.containers import ReplayBuffer, PrioritizedReplayBuffer
from gfn.gflownet import (
DBGFlowNet,
FMGFlowNet,
@@ -185,12 +185,21 @@ def main(args): # noqa: C901
objects_type = "states"
else:
raise NotImplementedError(f"Unknown loss: {args.loss}")
replay_buffer = ReplayBuffer(
env, objects_type=objects_type, capacity=args.replay_buffer_size
)

# 3. Create the optimizer
if args.replay_buffer_prioritized:
replay_buffer = PrioritizedReplayBuffer(
env,
objects_type=objects_type,
capacity=args.replay_buffer_size,
p_norm_distance=1, # Use L1-norm for diversity estimation.
cutoff_distance=0, # -1 turns off diversity-based filtering.
)
else:
replay_buffer = ReplayBuffer(
env, objects_type=objects_type, capacity=args.replay_buffer_size
)

# 3. Create the optimizer
# Policy parameters have their own LR.
params = [
{
@@ -292,6 +301,11 @@ def main(args): # noqa: C901
default=0,
help="If zero, no replay buffer is used. Otherwise, the replay buffer is used.",
)
parser.add_argument(
"--replay_buffer_prioritized",
action="store_true",
help="If set and replay_buffer_size > 0, use a prioritized replay buffer.",
)

parser.add_argument(
"--loss",
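Outside of the command-line flag, the prioritized buffer can also be constructed directly. A minimal sketch, assuming the `HyperGrid` environment from `gfn.gym` that this tutorial script trains on (the `ndim`/`height` values are illustrative):

    from gfn.containers import PrioritizedReplayBuffer
    from gfn.gym import HyperGrid

    env = HyperGrid(ndim=2, height=8)
    replay_buffer = PrioritizedReplayBuffer(
        env,
        objects_type="trajectories",
        capacity=1000,
        cutoff_distance=0.1,  # Negative values skip diversity filtering entirely.
        p_norm_distance=1.0,  # L1 distance between final states.
    )
    # Sampled trajectories are passed to replay_buffer.add(...), which keeps only
    # the highest-reward, sufficiently diverse ones once the buffer is full.
    print(len(replay_buffer))  # 0 before anything has been added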