diff --git a/src/gfn/containers/__init__.py b/src/gfn/containers/__init__.py
index f3c3c9a5..0acaab06 100644
--- a/src/gfn/containers/__init__.py
+++ b/src/gfn/containers/__init__.py
@@ -1,3 +1,3 @@
-from .replay_buffer import ReplayBuffer
+from .replay_buffer import ReplayBuffer, PrioritizedReplayBuffer
 from .trajectories import Trajectories
 from .transitions import Transitions
diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py
index c1679d9a..2cf9fcc6 100644
--- a/src/gfn/containers/replay_buffer.py
+++ b/src/gfn/containers/replay_buffer.py
@@ -18,7 +18,6 @@ class ReplayBuffer:
 
     Attributes:
         env: the Environment instance.
-        loss_fn: the Loss instance
         capacity: the size of the buffer.
         training_objects: the buffer of objects used for training.
         terminating_states: a States class representation of $s_f$.
@@ -56,13 +55,12 @@ def __init__(
             raise ValueError(f"Unknown objects_type: {objects_type}")
 
         self._is_full = False
-        self._index = 0
 
     def __repr__(self):
         return f"ReplayBuffer(capacity={self.capacity}, containing {len(self)} {self.objects_type})"
 
     def __len__(self):
-        return self.capacity if self._is_full else self._index
+        return len(self.training_objects)
 
     def add(self, training_objects: Transitions | Trajectories | tuple[States]):
         """Adds a training object to the buffer."""
@@ -73,8 +71,7 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]):
 
         to_add = len(training_objects)
 
-        self._is_full |= self._index + to_add >= self.capacity
-        self._index = (self._index + to_add) % self.capacity
+        self._is_full |= len(self) + to_add >= self.capacity
 
         self.training_objects.extend(training_objects)
         self.training_objects = self.training_objects[-self.capacity :]
@@ -102,6 +99,140 @@ def save(self, directory: str):
     def load(self, directory: str):
         """Loads the buffer from disk."""
         self.training_objects.load(os.path.join(directory, "training_objects"))
-        self._index = len(self.training_objects)
         if self.terminating_states is not None:
             self.terminating_states.load(os.path.join(directory, "terminating_states"))
+
+
+class PrioritizedReplayBuffer(ReplayBuffer):
+    """A prioritized replay buffer of trajectories or transitions.
+
+    Attributes:
+        env: the Environment instance.
+        capacity: the size of the buffer.
+        training_objects: the buffer of objects used for training.
+        terminating_states: a States class representation of $s_f$.
+        objects_type: the type of buffer (transitions, trajectories, or states).
+        cutoff_distance: threshold used to determine if new last_states are different
+            enough from those already contained in the buffer.
+        p_norm_distance: p-norm distance value to pass to torch.cdist, for the
+            determination of novel states.
+    """
+
+    def __init__(
+        self,
+        env: Env,
+        objects_type: Literal["transitions", "trajectories", "states"],
+        capacity: int = 1000,
+        cutoff_distance: float = 0.0,
+        p_norm_distance: float = 1.0,
+    ):
+        """Instantiates a prioritized replay buffer.
+
+        Args:
+            env: the Environment instance.
+            objects_type: the type of buffer (transitions, trajectories, or states).
+            capacity: the size of the buffer.
+            cutoff_distance: threshold used to determine if new last_states are
+                different enough from those already contained in the buffer. If the
+                cutoff is negative, all diversity calculations are skipped (since all
+                norms are >= 0).
+            p_norm_distance: p-norm distance value to pass to torch.cdist, for the
+                determination of novel states.
+ """ + super().__init__(env, objects_type, capacity) + self.cutoff_distance = cutoff_distance + self.p_norm_distance = p_norm_distance + + def _add_objs(self, training_objects: Transitions | Trajectories | tuple[States]): + """Adds a training object to the buffer.""" + # Adds the objects to the buffer. + self.training_objects.extend(training_objects) + + # Sort elements by logreward, capping the size at the defined capacity. + ix = torch.argsort(self.training_objects.log_rewards) + self.training_objects = self.training_objects[ix] + self.training_objects = self.training_objects[-self.capacity :] + + # Add the terminating states to the buffer. + if self.terminating_states is not None: + assert terminating_states is not None + self.terminating_states.extend(terminating_states) + + # Sort terminating states by logreward as well. + self.terminating_states = self.terminating_states[ix] + self.terminating_states = self.terminating_states[-self.capacity :] + + def add(self, training_objects: Transitions | Trajectories | tuple[States]): + """Adds a training object to the buffer.""" + terminating_states = None + if isinstance(training_objects, tuple): + assert self.objects_type == "states" and self.terminating_states is not None + training_objects, terminating_states = training_objects + + to_add = len(training_objects) + self._is_full |= len(self) + to_add >= self.capacity + + # The buffer isn't full yet. + if len(self.training_objects) < self.capacity: + self._add_objs(training_objects) + + # Our buffer is full and we will prioritize diverse, high reward additions. + else: + # Sort the incoming elements by their logrewards. + ix = torch.argsort(training_objects.log_rewards, descending=True) + training_objects = training_objects[ix] + + # Filter all batch logrewards lower than the smallest logreward in buffer. + min_reward_in_buffer = self.training_objects.log_rewards.min() + idx_bigger_rewards = training_objects.log_rewards >= min_reward_in_buffer + training_objects = training_objects[idx_bigger_rewards] + + # TODO: Concatenate input with final state for conditional GFN. + # if self.is_conditional: + # batch = torch.cat( + # [dict_curr_batch["input"], dict_curr_batch["final_state"]], + # dim=-1, + # ) + # buffer = torch.cat( + # [self.storage["input"], self.storage["final_state"]], + # dim=-1, + # ) + + if self.cutoff_distance >= 0: + # Filter the batch for diverse final_states with high reward. + batch = training_objects.last_states.tensor.float() + batch_dim = training_objects.last_states.batch_shape[0] + batch_batch_dist = torch.cdist( + batch.view(batch_dim, -1).unsqueeze(0), + batch.view(batch_dim, -1).unsqueeze(0), + p=self.p_norm_distance, + ).squeeze(0) + + # Finds the min distance at each row, and removes rows below the cutoff. + r, w = torch.triu_indices(*batch_batch_dist.shape) # Remove upper diag. + batch_batch_dist[r, w] = torch.finfo(batch_batch_dist.dtype).max + batch_batch_dist = batch_batch_dist.min(-1)[0] + idx_batch_batch = batch_batch_dist > self.cutoff_distance + training_objects = training_objects[idx_batch_batch] + + # Compute all pairwise distances between the remaining batch & buffer. 
+                batch = training_objects.last_states.tensor.float()
+                buffer = self.training_objects.last_states.tensor.float()
+                batch_dim = training_objects.last_states.batch_shape[0]
+                buffer_dim = self.training_objects.last_states.batch_shape[0]
+                batch_buffer_dist = (
+                    torch.cdist(
+                        batch.view(batch_dim, -1).unsqueeze(0),
+                        buffer.view(buffer_dim, -1).unsqueeze(0),
+                        p=self.p_norm_distance,
+                    )
+                    .squeeze(0)
+                    .min(-1)[0]  # Min calculated over rows - the batch elements.
+                )
+
+                # Filter the batch for diverse final_states w.r.t. the buffer.
+                idx_batch_buffer = batch_buffer_dist > self.cutoff_distance
+                training_objects = training_objects[idx_batch_buffer]
+
+            # If any training objects remain after filtering, add them.
+            if len(training_objects):
+                self._add_objs(training_objects)
diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py
index 35196ec3..cc02bda1 100644
--- a/src/gfn/containers/trajectories.py
+++ b/src/gfn/containers/trajectories.py
@@ -90,7 +90,11 @@ def __init__(
             if when_is_done is not None
             else torch.full(size=(0,), fill_value=-1, dtype=torch.long)
         )
-        self._log_rewards = log_rewards
+        self._log_rewards = (
+            log_rewards
+            if log_rewards is not None
+            else torch.full(size=(0,), fill_value=0, dtype=torch.float)
+        )
         self.log_probs = (
             log_probs
             if log_probs is not None
@@ -232,22 +236,31 @@ def extend(self, other: Trajectories) -> None:
         self.states.extend(other.states)
         self.when_is_done = torch.cat((self.when_is_done, other.when_is_done), dim=0)
 
-        # For log_probs, we first need to make the first dimensions of self.log_probs and other.log_probs equal
-        # (i.e. the number of steps in the trajectories), and then concatenate them
+        # For log_probs, we first need to make the first dimensions of self.log_probs
+        # and other.log_probs equal (i.e. the number of steps in the trajectories), and
+        # then concatenate them.
         new_max_length = max(self.log_probs.shape[0], other.log_probs.shape[0])
         self.log_probs = self.extend_log_probs(self.log_probs, new_max_length)
         other.log_probs = self.extend_log_probs(other.log_probs, new_max_length)
-
         self.log_probs = torch.cat((self.log_probs, other.log_probs), dim=1)
 
+        # Concatenate log_rewards of the trajectories.
         if self._log_rewards is not None and other._log_rewards is not None:
             self._log_rewards = torch.cat(
                 (self._log_rewards, other._log_rewards),
                 dim=0,
             )
+        # Will not be None if the object was initialized as empty.
         else:
             self._log_rewards = None
 
+        # Ensure log_probs/rewards have the correct dimensions. TODO: Remove?
+        if self.log_probs.numel() > 0:
+            assert self.log_probs.shape == self.actions.batch_shape
+
+        if self.log_rewards is not None:
+            assert len(self.log_rewards) == self.actions.batch_shape[-1]
+
         # Either set, or append, estimator outputs if they exist in the submitted
         # trajectory.
         if self.estimator_outputs is None and isinstance(
diff --git a/src/gfn/containers/transitions.py b/src/gfn/containers/transitions.py
index a3c920af..cbc214f6 100644
--- a/src/gfn/containers/transitions.py
+++ b/src/gfn/containers/transitions.py
@@ -90,7 +90,7 @@ def __init__(
             len(self.next_states.batch_shape) == 1
             and self.states.batch_shape == self.next_states.batch_shape
         )
-        self._log_rewards = log_rewards
+        self._log_rewards = log_rewards if log_rewards is not None else torch.zeros(0)
         self.log_probs = log_probs if log_probs is not None else torch.zeros(0)
 
     @property
@@ -208,10 +208,13 @@ def extend(self, other: Transitions) -> None:
         self.actions.extend(other.actions)
         self.is_done = torch.cat((self.is_done, other.is_done), dim=0)
         self.next_states.extend(other.next_states)
+
+        # Concatenate log_rewards of the transitions.
         if self._log_rewards is not None and other._log_rewards is not None:
             self._log_rewards = torch.cat(
                 (self._log_rewards, other._log_rewards), dim=0
             )
+        # Will not be None if the object was initialized as empty.
         else:
             self._log_rewards = None
         self.log_probs = torch.cat((self.log_probs, other.log_probs), dim=0)
diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py
index 2041c7ca..eec3366b 100644
--- a/tutorials/examples/train_hypergrid.py
+++ b/tutorials/examples/train_hypergrid.py
@@ -17,7 +17,7 @@
 import wandb
 from tqdm import tqdm, trange
 
-from gfn.containers import ReplayBuffer
+from gfn.containers import ReplayBuffer, PrioritizedReplayBuffer
 from gfn.gflownet import (
     DBGFlowNet,
     FMGFlowNet,
@@ -185,12 +185,21 @@ def main(args):  # noqa: C901
         objects_type = "states"
     else:
         raise NotImplementedError(f"Unknown loss: {args.loss}")
-    replay_buffer = ReplayBuffer(
-        env, objects_type=objects_type, capacity=args.replay_buffer_size
-    )
 
-    # 3. Create the optimizer
+    if args.replay_buffer_prioritized:
+        replay_buffer = PrioritizedReplayBuffer(
+            env,
+            objects_type=objects_type,
+            capacity=args.replay_buffer_size,
+            p_norm_distance=1,  # Use the L1-norm for diversity estimation.
+            cutoff_distance=0,  # Negative values turn off diversity-based filtering.
+        )
+    else:
+        replay_buffer = ReplayBuffer(
+            env, objects_type=objects_type, capacity=args.replay_buffer_size
+        )
+
+    # 3. Create the optimizer
     # Policy parameters have their own LR.
     params = [
         {
@@ -292,6 +301,11 @@ def main(args):  # noqa: C901
         default=0,
         help="If zero, no replay buffer is used. Otherwise, the replay buffer is used.",
     )
+    parser.add_argument(
+        "--replay_buffer_prioritized",
+        action="store_true",
+        help="If set and replay_buffer_size > 0, use a prioritized replay buffer.",
+    )
     parser.add_argument(
         "--loss",
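
Note: as a standalone illustration of the two-stage diversity filter that PrioritizedReplayBuffer.add applies when the buffer is full (within-batch filtering, then batch-vs-buffer filtering), the sketch below runs the same torch.cdist / torch.triu_indices logic on plain tensors. The batch, buffer, and threshold values here are invented for the example and are not part of the patch.

import torch

# Hypothetical stand-ins for last_states tensors: `batch` holds the incoming
# candidates, `buffer` what the replay buffer already contains.
cutoff_distance = 0.5
p_norm_distance = 1.0
batch = torch.tensor([[0.0, 0.0], [0.1, 0.0], [3.0, 4.0]])
buffer = torch.tensor([[10.0, 10.0], [0.05, 0.0]])

# Stage 1: pairwise distances within the batch. Masking the upper triangle
# (diagonal included) means each row's min only looks at earlier elements,
# so the first element always survives.
batch_batch_dist = torch.cdist(
    batch.unsqueeze(0), batch.unsqueeze(0), p=p_norm_distance
).squeeze(0)
r, w = torch.triu_indices(*batch_batch_dist.shape)
batch_batch_dist[r, w] = torch.finfo(batch_batch_dist.dtype).max
batch = batch[batch_batch_dist.min(-1)[0] > cutoff_distance]

# Stage 2: distance from each surviving batch element to its nearest buffer
# element; only elements far enough from everything in the buffer are kept.
batch_buffer_dist = (
    torch.cdist(batch.unsqueeze(0), buffer.unsqueeze(0), p=p_norm_distance)
    .squeeze(0)
    .min(-1)[0]
)
print(batch[batch_buffer_dist > cutoff_distance])  # tensor([[3., 4.]])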