diff --git a/src/gfn/containers/__init__.py b/src/gfn/containers/__init__.py
index f3c3c9a5..0acaab06 100644
--- a/src/gfn/containers/__init__.py
+++ b/src/gfn/containers/__init__.py
@@ -1,3 +1,3 @@
-from .replay_buffer import ReplayBuffer
+from .replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
 from .trajectories import Trajectories
 from .transitions import Transitions
diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py
index c1679d9a..ee24c882 100644
--- a/src/gfn/containers/replay_buffer.py
+++ b/src/gfn/containers/replay_buffer.py
@@ -18,7 +18,6 @@ class ReplayBuffer:
     Attributes:
         env: the Environment instance.
-        loss_fn: the Loss instance
        capacity: the size of the buffer.
         training_objects: the buffer of objects used for training.
         terminating_states: a States class representation of $s_f$.
@@ -105,3 +104,137 @@ def load(self, directory: str):
         self._index = len(self.training_objects)
         if self.terminating_states is not None:
             self.terminating_states.load(os.path.join(directory, "terminating_states"))
+
+
+class PrioritizedReplayBuffer(ReplayBuffer):
+    """A replay buffer that prioritizes diverse, high log-reward training objects.
+
+    Attributes:
+        env: the Environment instance.
+        capacity: the size of the buffer.
+        training_objects: the buffer of objects used for training.
+        terminating_states: a States class representation of $s_f$.
+        objects_type: the type of buffer (transitions, trajectories, or states).
+        cutoff_distance: threshold used to determine if new last_states are different
+            enough from those already contained in the buffer.
+        p_norm_distance: p-norm distance value to pass to torch.cdist, for the
+            determination of novel states.
+    """
+
+    def __init__(
+        self,
+        env: Env,
+        objects_type: Literal["transitions", "trajectories", "states"],
+        capacity: int = 1000,
+        cutoff_distance: float = 0.0,
+        p_norm_distance: float = 1.0,
+    ):
+        """Instantiates a prioritized replay buffer.
+
+        Args:
+            env: the Environment instance.
+            objects_type: the type of buffer (transitions, trajectories, or states).
+            capacity: the size of the buffer.
+            cutoff_distance: threshold used to determine if new last_states are
+                different enough from those already contained in the buffer.
+            p_norm_distance: p-norm distance value to pass to torch.cdist, for the
+                determination of novel states.
+        """
+        super().__init__(env, objects_type, capacity)
+        self.cutoff_distance = cutoff_distance
+        self.p_norm_distance = p_norm_distance
+
+    def _add_objs(
+        self,
+        training_objects: Transitions | Trajectories | tuple[States],
+        terminating_states: States | None = None,
+    ):
+        """Adds a batch of training objects (and optional terminating states)."""
+        # Adds the objects to the buffer.
+        self.training_objects.extend(training_objects)
+
+        # Sort elements by log-reward, capping the size at the defined capacity.
+        ix = torch.argsort(self.training_objects.log_rewards)
+        self.training_objects = self.training_objects[ix]
+        self.training_objects = self.training_objects[-self.capacity :]
+
+        # Add the terminating states to the buffer.
+        if self.terminating_states is not None:
+            assert terminating_states is not None
+            self.terminating_states.extend(terminating_states)
+
+            # Sort terminating states by log-reward as well.
+            self.terminating_states = self.terminating_states[ix]
+            self.terminating_states = self.terminating_states[-self.capacity :]
+
+    def add(self, training_objects: Transitions | Trajectories | tuple[States]):
+        """Adds a batch of training objects to the buffer."""
+        terminating_states = None
+        if isinstance(training_objects, tuple):
+            assert self.objects_type == "states" and self.terminating_states is not None
+            training_objects, terminating_states = training_objects
+
+        to_add = len(training_objects)
+
+        self._is_full |= self._index + to_add >= self.capacity
+        self._index = (self._index + to_add) % self.capacity
+
+        # The buffer isn't full yet.
+        if len(self.training_objects) < self.capacity:
+            self._add_objs(training_objects, terminating_states)
+
+        # Our buffer is full and we will prioritize diverse, high-reward additions.
+        else:
+            # Sort the incoming elements by their log-rewards.
+            ix = torch.argsort(training_objects.log_rewards, descending=True)
+            training_objects = training_objects[ix]
+
+            # Filter out all batch elements whose log-reward is lower than the
+            # smallest log-reward in the buffer.
+            min_reward_in_buffer = self.training_objects.log_rewards.min()
+            idx_bigger_rewards = training_objects.log_rewards > min_reward_in_buffer
+            training_objects = training_objects[idx_bigger_rewards]
+
+            # Compute all pairwise distances between the batch and the buffer.
+            curr_dim = training_objects.last_states.batch_shape[0]
+            buffer_dim = self.training_objects.last_states.batch_shape[0]
+
+            # TODO: Concatenate input with final state for conditional GFN.
+            # if self.is_conditional:
+            #     batch = torch.cat(
+            #         [dict_curr_batch["input"], dict_curr_batch["final_state"]],
+            #         dim=-1,
+            #     )
+            #     buffer = torch.cat(
+            #         [self.storage["input"], self.storage["final_state"]],
+            #         dim=-1,
+            #     )
+            batch = training_objects.last_states.tensor.float()
+            buffer = self.training_objects.last_states.tensor.float()
+
+            # Filter the batch for diverse final_states with high reward.
+            batch_batch_dist = torch.cdist(
+                batch.view(curr_dim, -1).unsqueeze(0),
+                batch.view(curr_dim, -1).unsqueeze(0),
+                p=self.p_norm_distance,
+            ).squeeze(0)
+
+            # Mask the upper triangle (including the diagonal) so each element is
+            # compared only against the higher-reward elements that precede it.
+            r, w = torch.triu_indices(*batch_batch_dist.shape)
+            batch_batch_dist[r, w] = torch.finfo(batch_batch_dist.dtype).max
+            batch_batch_dist = batch_batch_dist.min(-1)[0]
+
+            # Filter the batch for diverse final_states w.r.t. the buffer.
+            batch_buffer_dist = (
+                torch.cdist(
+                    batch.view(curr_dim, -1).unsqueeze(0),
+                    buffer.view(buffer_dim, -1).unsqueeze(0),
+                    p=self.p_norm_distance,
+                )
+                .squeeze(0)
+                .min(-1)[0]
+            )
+
+            # Remove non-diverse examples according to the above distances.
+            idx_batch_batch = batch_batch_dist > self.cutoff_distance
+            idx_batch_buffer = batch_buffer_dist > self.cutoff_distance
+            idx_diverse = idx_batch_batch & idx_batch_buffer
+
+            training_objects = training_objects[idx_diverse]
+
+            # If any training objects remain after filtering, add them.
+            if len(training_objects):
+                self._add_objs(training_objects, terminating_states)
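Note (illustration only, not part of the patch): the prioritization in _add_objs reduces to concatenating the incoming batch with the buffer contents, sorting by log-reward, and keeping the capacity highest-reward entries. A minimal standalone sketch of that behaviour on plain tensors, with made-up values for the capacity and the log-rewards; the real method applies the same sort index to the stored training objects and terminating states rather than to a bare tensor:

import torch

capacity = 5
buffer_log_rewards = torch.tensor([0.1, 0.7, 0.3, 0.9, 0.5])  # already in the buffer
batch_log_rewards = torch.tensor([0.8, 0.05, 0.6])            # incoming batch

# Concatenate, sort ascending by log-reward, keep the `capacity` largest entries.
all_log_rewards = torch.cat([buffer_log_rewards, batch_log_rewards])
ix = torch.argsort(all_log_rewards)
kept = all_log_rewards[ix][-capacity:]
print(kept)  # tensor([0.5000, 0.6000, 0.7000, 0.8000, 0.9000])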
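Note (illustration only, not part of the patch): the diversity filter in add() can be exercised on raw final-state tensors. A minimal standalone sketch, assuming 2-D final states and made-up values for cutoff_distance and p_norm_distance; a candidate is kept only if it is farther than cutoff_distance from every higher-reward candidate in the same batch and from every state already in the buffer:

import torch

cutoff_distance, p_norm_distance = 0.5, 1.0

batch = torch.tensor([[0.0, 0.0], [0.1, 0.1], [3.0, 3.0]])  # candidate final states, sorted by reward
buffer = torch.tensor([[0.0, 0.1], [5.0, 5.0]])             # final states already stored

curr_dim, buffer_dim = batch.shape[0], buffer.shape[0]

# Pairwise distances within the batch; mask the upper triangle (including the
# diagonal) so each candidate is compared only to the candidates before it.
batch_batch_dist = torch.cdist(
    batch.view(curr_dim, -1).unsqueeze(0),
    batch.view(curr_dim, -1).unsqueeze(0),
    p=p_norm_distance,
).squeeze(0)
r, w = torch.triu_indices(*batch_batch_dist.shape)
batch_batch_dist[r, w] = torch.finfo(batch_batch_dist.dtype).max
batch_batch_dist = batch_batch_dist.min(-1)[0]

# Distance of each candidate to its nearest neighbour in the buffer.
batch_buffer_dist = (
    torch.cdist(
        batch.view(curr_dim, -1).unsqueeze(0),
        buffer.view(buffer_dim, -1).unsqueeze(0),
        p=p_norm_distance,
    )
    .squeeze(0)
    .min(-1)[0]
)

# Keep candidates far enough from both earlier candidates and the buffer.
idx_diverse = (batch_batch_dist > cutoff_distance) & (batch_buffer_dist > cutoff_distance)
print(idx_diverse)  # tensor([False, False,  True])

Masking the upper triangle with torch.finfo(...).max means the first (highest-reward) candidate is always considered diverse within the batch, and every later candidate only has to clear the cutoff against the candidates ranked above it.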