From e087f414bc9e63c83815e59756cb96e262d5045a Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 30 Mar 2024 15:54:57 -0400 Subject: [PATCH 01/10] log_rewards are stored properly in the case that the external trajectory contains log_rewards and the internal trajectory is None (this can happen with empty initalized trajectory) --- src/gfn/containers/trajectories.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index 35196ec3..6002a330 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -232,22 +232,34 @@ def extend(self, other: Trajectories) -> None: self.states.extend(other.states) self.when_is_done = torch.cat((self.when_is_done, other.when_is_done), dim=0) - # For log_probs, we first need to make the first dimensions of self.log_probs and other.log_probs equal - # (i.e. the number of steps in the trajectories), and then concatenate them + # For log_probs, we first need to make the first dimensions of self.log_probs + # and other.log_probs equal (i.e. the number of steps in the trajectories), and + # then concatenate them. new_max_length = max(self.log_probs.shape[0], other.log_probs.shape[0]) self.log_probs = self.extend_log_probs(self.log_probs, new_max_length) other.log_probs = self.extend_log_probs(other.log_probs, new_max_length) - self.log_probs = torch.cat((self.log_probs, other.log_probs), dim=1) + # Concatenate log_rewards of the trajectories. if self._log_rewards is not None and other._log_rewards is not None: self._log_rewards = torch.cat( (self._log_rewards, other._log_rewards), dim=0, ) + # If the trajectories object does not yet have `log_rewards` assigned but the + # external trajectory has log_rewards, simply assign them over. + elif self._log_rewards is None and other._log_rewards is not None: + self._log_rewards = other._log_rewards else: self._log_rewards = None + # Ensure log_probs/rewards are the correct dimensions. TODO: Remove? + if self.log_probs.numel() > 0: + assert len(self.log_probs) == self.states.batch_shape[-1] + + if self.log_rewards is not None: + assert len(self.log_rewards) == self.states.batch_shape[-1] + # Either set, or append, estimator outputs if they exist in the submitted # trajectory. if self.estimator_outputs is None and isinstance( From 61f7fd281575cd48d8af0d262ee02bacf92b324f Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 30 Mar 2024 18:55:51 -0400 Subject: [PATCH 02/10] can use either standard or prioritized replay buffer --- tutorials/examples/train_hypergrid.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 2041c7ca..efef2645 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -17,7 +17,7 @@ import wandb from tqdm import tqdm, trange -from gfn.containers import ReplayBuffer +from gfn.containers import ReplayBuffer, PrioritizedReplayBuffer from gfn.gflownet import ( DBGFlowNet, FMGFlowNet, @@ -185,12 +185,17 @@ def main(args): # noqa: C901 objects_type = "states" else: raise NotImplementedError(f"Unknown loss: {args.loss}") - replay_buffer = ReplayBuffer( - env, objects_type=objects_type, capacity=args.replay_buffer_size - ) - # 3. 
Create the optimizer + if args.replay_buffer_prioritized: + replay_buffer = PrioritizedReplayBuffer( + env, objects_type=objects_type, capacity=args.replay_buffer_size + ) + else: + replay_buffer = ReplayBuffer( + env, objects_type=objects_type, capacity=args.replay_buffer_size + ) + # 3. Create the optimizer # Policy parameters have their own LR. params = [ { @@ -292,6 +297,11 @@ def main(args): # noqa: C901 default=0, help="If zero, no replay buffer is used. Otherwise, the replay buffer is used.", ) + parser.add_argument( + "--replay_buffer_prioritized", + action="store_true", + help="If set and replay_buffer_size > 0, use a prioritized replay buffer.", + ) parser.add_argument( "--loss", From 75e319844fcfe44218143a05fee5975a40fd79df Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 30 Mar 2024 18:56:18 -0400 Subject: [PATCH 03/10] added a prioritized replay buffer --- src/gfn/containers/__init__.py | 2 +- src/gfn/containers/replay_buffer.py | 135 +++++++++++++++++++++++++++- 2 files changed, 135 insertions(+), 2 deletions(-) diff --git a/src/gfn/containers/__init__.py b/src/gfn/containers/__init__.py index f3c3c9a5..0acaab06 100644 --- a/src/gfn/containers/__init__.py +++ b/src/gfn/containers/__init__.py @@ -1,3 +1,3 @@ -from .replay_buffer import ReplayBuffer +from .replay_buffer import ReplayBuffer, PrioritizedReplayBuffer from .trajectories import Trajectories from .transitions import Transitions diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py index c1679d9a..ee24c882 100644 --- a/src/gfn/containers/replay_buffer.py +++ b/src/gfn/containers/replay_buffer.py @@ -18,7 +18,6 @@ class ReplayBuffer: Attributes: env: the Environment instance. - loss_fn: the Loss instance capacity: the size of the buffer. training_objects: the buffer of objects used for training. terminating_states: a States class representation of $s_f$. @@ -105,3 +104,137 @@ def load(self, directory: str): self._index = len(self.training_objects) if self.terminating_states is not None: self.terminating_states.load(os.path.join(directory, "terminating_states")) + + +class PrioritizedReplayBuffer(ReplayBuffer): + """A replay buffer of trajectories or transitions. + + Attributes: + env: the Environment instance. + capacity: the size of the buffer. + training_objects: the buffer of objects used for training. + terminating_states: a States class representation of $s_f$. + objects_type: the type of buffer (transitions, trajectories, or states). + cutoff_distance: threshold used to determine if new last_states are different + enough from those already contained in the buffer. + p_norm_distance: p-norm distance value to pass to torch.cdist, for the + determination of novel states. + """ + def __init__( + self, + env: Env, + objects_type: Literal["transitions", "trajectories", "states"], + capacity: int = 1000, + cutoff_distance: float = 0., + p_norm_distance: float = 1., + ): + """Instantiates a prioritized replay buffer. + Args: + env: the Environment instance. + loss_fn: the Loss instance. + capacity: the size of the buffer. + objects_type: the type of buffer (transitions, trajectories, or states). + cutoff_distance: threshold used to determine if new last_states are + different enough from those already contained in the buffer. + p_norm_distance: p-norm distance value to pass to torch.cdist, for the + determination of novel states. 
+ """ + super().__init__(env, objects_type, capacity) + self.cutoff_distance = cutoff_distance + self.p_norm_distance = p_norm_distance + + def _add_objs(self, training_objects: Transitions | Trajectories | tuple[States]): + """Adds a training object to the buffer.""" + # Adds the objects to the buffer. + self.training_objects.extend(training_objects) + + # Sort elements by logreward, capping the size at the defined capacity. + ix = torch.argsort(self.training_objects.log_rewards) + self.training_objects = self.training_objects[ix] + self.training_objects = self.training_objects[-self.capacity :] + + # Add the terminating states to the buffer. + if self.terminating_states is not None: + assert terminating_states is not None + self.terminating_states.extend(terminating_states) + + # Sort terminating states by logreward as well. + self.terminating_states = self.terminating_states[ix] + self.terminating_states = self.terminating_states[-self.capacity :] + + def add(self, training_objects: Transitions | Trajectories | tuple[States]): + """Adds a training object to the buffer.""" + terminating_states = None + if isinstance(training_objects, tuple): + assert self.objects_type == "states" and self.terminating_states is not None + training_objects, terminating_states = training_objects + + to_add = len(training_objects) + + self._is_full |= self._index + to_add >= self.capacity + self._index = (self._index + to_add) % self.capacity + + # The buffer isn't full yet. + if len(self.training_objects) < self.capacity: + self._add_objs(training_objects) + + # Our buffer is full and we will prioritize diverse, high reward additions. + else: + # Sort the incoming elements by their logrewards. + ix = torch.argsort(training_objects._log_rewards, descending=True) + training_objects = training_objects[ix] + + # Filter all batch logrewards lower than the smallest logreward in buffer. + min_reward_in_buffer = self.training_objects.log_rewards.min() + idx_bigger_rewards = training_objects.log_rewards > min_reward_in_buffer + training_objects = training_objects[idx_bigger_rewards] + + # Compute all pairwise distances between the batch and the buffer. + curr_dim = training_objects.last_states.batch_shape[0] + buffer_dim = self.training_objects.last_states.batch_shape[0] + + # TODO: Concatenate input with final state for conditional GFN. + # if self.is_conditional: + # batch = torch.cat( + # [dict_curr_batch["input"], dict_curr_batch["final_state"]], + # dim=-1, + # ) + # buffer = torch.cat( + # [self.storage["input"], self.storage["final_state"]], + # dim=-1, + # ) + batch = training_objects.last_states.tensor.float() + buffer = self.training_objects.last_states.tensor.float() + + # Filter the batch for diverse final_states with high reward. + batch_batch_dist = torch.cdist( + batch.view(curr_dim, -1).unsqueeze(0), + batch.view(curr_dim, -1).unsqueeze(0), + p=self.p_norm_distance, + ).squeeze(0) + + r, w = torch.triu_indices(*batch_batch_dist.shape) # Remove upper diag. + batch_batch_dist[r, w] = torch.finfo(batch_batch_dist.dtype).max + batch_batch_dist = batch_batch_dist.min(-1)[0] + + # Filter the batch for diverse final_states w.r.t the buffer. + batch_buffer_dist = ( + torch.cdist( + batch.view(curr_dim, -1).unsqueeze(0), + buffer.view(buffer_dim, -1).unsqueeze(0), + p=self.p_norm_distance, + ) + .squeeze(0) + .min(-1)[0] + ) + + # Remove non-diverse examples according to the above distances. 
+ idx_batch_batch = batch_batch_dist > self.cutoff_distance + idx_batch_buffer = batch_buffer_dist > self.cutoff_distance + idx_diverse = idx_batch_batch & idx_batch_buffer + + training_objects = training_objects[idx_diverse] + + # If any training object remain after filtering, add them. + if len(training_objects): + self._add_objs(training_objects) From 7449a923d535c13d5477cf938edc35f403bc16e0 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sun, 31 Mar 2024 12:27:24 -0400 Subject: [PATCH 04/10] bugfix on assert --- src/gfn/containers/trajectories.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index 6002a330..d52274f8 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -255,10 +255,10 @@ def extend(self, other: Trajectories) -> None: # Ensure log_probs/rewards are the correct dimensions. TODO: Remove? if self.log_probs.numel() > 0: - assert len(self.log_probs) == self.states.batch_shape[-1] + assert self.log_probs.shape == self.actions.batch_shape if self.log_rewards is not None: - assert len(self.log_rewards) == self.states.batch_shape[-1] + assert len(self.log_rewards) == self.actions.batch_shape[-1] # Either set, or append, estimator outputs if they exist in the submitted # trajectory. From 184c5f59f0fddeb5fdb8a01135a581782117362a Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 2 Apr 2024 12:05:43 -0400 Subject: [PATCH 05/10] removed self._index --- src/gfn/containers/replay_buffer.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py index ee24c882..2bc870ca 100644 --- a/src/gfn/containers/replay_buffer.py +++ b/src/gfn/containers/replay_buffer.py @@ -55,13 +55,12 @@ def __init__( raise ValueError(f"Unknown objects_type: {objects_type}") self._is_full = False - self._index = 0 def __repr__(self): return f"ReplayBuffer(capacity={self.capacity}, containing {len(self)} {self.objects_type})" def __len__(self): - return self.capacity if self._is_full else self._index + return len(self.training_objects) def add(self, training_objects: Transitions | Trajectories | tuple[States]): """Adds a training object to the buffer.""" @@ -72,8 +71,7 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]): to_add = len(training_objects) - self._is_full |= self._index + to_add >= self.capacity - self._index = (self._index + to_add) % self.capacity + self._is_full |= len(self) + to_add >= self.capacity self.training_objects.extend(training_objects) self.training_objects = self.training_objects[-self.capacity :] @@ -101,7 +99,6 @@ def save(self, directory: str): def load(self, directory: str): """Loads the buffer from disk.""" self.training_objects.load(os.path.join(directory, "training_objects")) - self._index = len(self.training_objects) if self.terminating_states is not None: self.terminating_states.load(os.path.join(directory, "terminating_states")) @@ -171,8 +168,7 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]): to_add = len(training_objects) - self._is_full |= self._index + to_add >= self.capacity - self._index = (self._index + to_add) % self.capacity + self._is_full |= len(self) + to_add >= self.capacity # The buffer isn't full yet. 
if len(self.training_objects) < self.capacity: From 00cab17381fe001859b7aed0aa2afd151428204f Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 2 Apr 2024 12:06:28 -0400 Subject: [PATCH 06/10] trajectories and transitions now initalize log_rewards correctly, so extend always works, which means we no longer need the extra condition (but we are leaving in the sanity check for now). --- src/gfn/containers/trajectories.py | 11 ++++++----- src/gfn/containers/transitions.py | 4 +++- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index d52274f8..cc02bda1 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -90,7 +90,11 @@ def __init__( if when_is_done is not None else torch.full(size=(0,), fill_value=-1, dtype=torch.long) ) - self._log_rewards = log_rewards + self._log_rewards = ( + log_rewards + if log_rewards is not None + else torch.full(size=(0,), fill_value=0, dtype=torch.float) + ) self.log_probs = ( log_probs if log_probs is not None @@ -246,10 +250,7 @@ def extend(self, other: Trajectories) -> None: (self._log_rewards, other._log_rewards), dim=0, ) - # If the trajectories object does not yet have `log_rewards` assigned but the - # external trajectory has log_rewards, simply assign them over. - elif self._log_rewards is None and other._log_rewards is not None: - self._log_rewards = other._log_rewards + # Will not be None if object is initialized as empty. else: self._log_rewards = None diff --git a/src/gfn/containers/transitions.py b/src/gfn/containers/transitions.py index a3c920af..8646e238 100644 --- a/src/gfn/containers/transitions.py +++ b/src/gfn/containers/transitions.py @@ -90,7 +90,7 @@ def __init__( len(self.next_states.batch_shape) == 1 and self.states.batch_shape == self.next_states.batch_shape ) - self._log_rewards = log_rewards + self._log_rewards = log_rewards if log_rewards is not None else torch.zeros(0) self.log_probs = log_probs if log_probs is not None else torch.zeros(0) @property @@ -208,6 +208,8 @@ def extend(self, other: Transitions) -> None: self.actions.extend(other.actions) self.is_done = torch.cat((self.is_done, other.is_done), dim=0) self.next_states.extend(other.next_states) + + # Concatenate log_rewards of the trajectories. if self._log_rewards is not None and other._log_rewards is not None: self._log_rewards = torch.cat( (self._log_rewards, other._log_rewards), dim=0 From a6703569565cbb611eb0a4f145c1d28fccf941a8 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 2 Apr 2024 12:46:28 -0400 Subject: [PATCH 07/10] small efficiency improvements to prioritized replay buffer --- src/gfn/containers/replay_buffer.py | 37 ++++++++++++++--------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py index 2bc870ca..fe1b1110 100644 --- a/src/gfn/containers/replay_buffer.py +++ b/src/gfn/containers/replay_buffer.py @@ -167,7 +167,6 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]): training_objects, terminating_states = training_objects to_add = len(training_objects) - self._is_full |= len(self) + to_add >= self.capacity # The buffer isn't full yet. @@ -177,18 +176,14 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]): # Our buffer is full and we will prioritize diverse, high reward additions. else: # Sort the incoming elements by their logrewards. 
- ix = torch.argsort(training_objects._log_rewards, descending=True) + ix = torch.argsort(training_objects.log_rewards, descending=True) training_objects = training_objects[ix] # Filter all batch logrewards lower than the smallest logreward in buffer. min_reward_in_buffer = self.training_objects.log_rewards.min() - idx_bigger_rewards = training_objects.log_rewards > min_reward_in_buffer + idx_bigger_rewards = training_objects.log_rewards >= min_reward_in_buffer training_objects = training_objects[idx_bigger_rewards] - # Compute all pairwise distances between the batch and the buffer. - curr_dim = training_objects.last_states.batch_shape[0] - buffer_dim = self.training_objects.last_states.batch_shape[0] - # TODO: Concatenate input with final state for conditional GFN. # if self.is_conditional: # batch = torch.cat( @@ -199,37 +194,41 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]): # [self.storage["input"], self.storage["final_state"]], # dim=-1, # ) - batch = training_objects.last_states.tensor.float() - buffer = self.training_objects.last_states.tensor.float() # Filter the batch for diverse final_states with high reward. + batch = training_objects.last_states.tensor.float() + batch_dim = training_objects.last_states.batch_shape[0] batch_batch_dist = torch.cdist( - batch.view(curr_dim, -1).unsqueeze(0), - batch.view(curr_dim, -1).unsqueeze(0), + batch.view(batch_dim, -1).unsqueeze(0), + batch.view(batch_dim, -1).unsqueeze(0), p=self.p_norm_distance, ).squeeze(0) + # Finds the min distance at each row, and removes rows below the cutoff. r, w = torch.triu_indices(*batch_batch_dist.shape) # Remove upper diag. batch_batch_dist[r, w] = torch.finfo(batch_batch_dist.dtype).max batch_batch_dist = batch_batch_dist.min(-1)[0] + idx_batch_batch = batch_batch_dist > self.cutoff_distance + training_objects = training_objects[idx_batch_batch] - # Filter the batch for diverse final_states w.r.t the buffer. + # Compute all pairwise distances between the remaining batch and the buffer. + batch = training_objects.last_states.tensor.float() + buffer = self.training_objects.last_states.tensor.float() + batch_dim = training_objects.last_states.batch_shape[0] + buffer_dim = self.training_objects.last_states.batch_shape[0] batch_buffer_dist = ( torch.cdist( - batch.view(curr_dim, -1).unsqueeze(0), + batch.view(batch_dim, -1).unsqueeze(0), buffer.view(buffer_dim, -1).unsqueeze(0), p=self.p_norm_distance, ) .squeeze(0) - .min(-1)[0] + .min(-1)[0] # Min calculated over rows, i.e., over the batch elements. ) - # Remove non-diverse examples according to the above distances. - idx_batch_batch = batch_batch_dist > self.cutoff_distance + # Filter the batch for diverse final_states w.r.t the buffer. idx_batch_buffer = batch_buffer_dist > self.cutoff_distance - idx_diverse = idx_batch_batch & idx_batch_buffer - - training_objects = training_objects[idx_diverse] + training_objects = training_objects[idx_batch_buffer] # If any training object remain after filtering, add them. 
if len(training_objects): From 213653c0de16808fe477e438fe80a43d6c429ffe Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 2 Apr 2024 17:38:13 -0400 Subject: [PATCH 08/10] changes to default settings for hypergrid --- tutorials/examples/train_hypergrid.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index efef2645..eec3366b 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -188,7 +188,11 @@ def main(args): # noqa: C901 if args.replay_buffer_prioritized: replay_buffer = PrioritizedReplayBuffer( - env, objects_type=objects_type, capacity=args.replay_buffer_size + env, + objects_type=objects_type, + capacity=args.replay_buffer_size, + p_norm_distance=1, # Use L1-norm for diversity estimation. + cutoff_distance=0, # -1 turns off diversity-based filtering. ) else: replay_buffer = ReplayBuffer( From 04e50e945809a914563c6c391b86ae939b838a9a Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 2 Apr 2024 17:38:55 -0400 Subject: [PATCH 09/10] reworking of prioritized replay buffer logic --- src/gfn/containers/replay_buffer.py | 65 +++++++++++++++-------------- 1 file changed, 34 insertions(+), 31 deletions(-) diff --git a/src/gfn/containers/replay_buffer.py b/src/gfn/containers/replay_buffer.py index fe1b1110..2cf9fcc6 100644 --- a/src/gfn/containers/replay_buffer.py +++ b/src/gfn/containers/replay_buffer.py @@ -132,7 +132,9 @@ def __init__( capacity: the size of the buffer. objects_type: the type of buffer (transitions, trajectories, or states). cutoff_distance: threshold used to determine if new last_states are - different enough from those already contained in the buffer. + different enough from those already contained in the buffer. If the + cutoff is negative, all diversity caclulations are skipped (since all + norms are >= 0). p_norm_distance: p-norm distance value to pass to torch.cdist, for the determination of novel states. """ @@ -195,40 +197,41 @@ def add(self, training_objects: Transitions | Trajectories | tuple[States]): # dim=-1, # ) - # Filter the batch for diverse final_states with high reward. - batch = training_objects.last_states.tensor.float() - batch_dim = training_objects.last_states.batch_shape[0] - batch_batch_dist = torch.cdist( - batch.view(batch_dim, -1).unsqueeze(0), - batch.view(batch_dim, -1).unsqueeze(0), - p=self.p_norm_distance, - ).squeeze(0) - - # Finds the min distance at each row, and removes rows below the cutoff. - r, w = torch.triu_indices(*batch_batch_dist.shape) # Remove upper diag. - batch_batch_dist[r, w] = torch.finfo(batch_batch_dist.dtype).max - batch_batch_dist = batch_batch_dist.min(-1)[0] - idx_batch_batch = batch_batch_dist > self.cutoff_distance - training_objects = training_objects[idx_batch_batch] - - # Compute all pairwise distances between the remaining batch and the buffer. - batch = training_objects.last_states.tensor.float() - buffer = self.training_objects.last_states.tensor.float() - batch_dim = training_objects.last_states.batch_shape[0] - buffer_dim = self.training_objects.last_states.batch_shape[0] - batch_buffer_dist = ( - torch.cdist( + if self.cutoff_distance >= 0: + # Filter the batch for diverse final_states with high reward. 
+ batch = training_objects.last_states.tensor.float() + batch_dim = training_objects.last_states.batch_shape[0] + batch_batch_dist = torch.cdist( + batch.view(batch_dim, -1).unsqueeze(0), batch.view(batch_dim, -1).unsqueeze(0), - buffer.view(buffer_dim, -1).unsqueeze(0), p=self.p_norm_distance, + ).squeeze(0) + + # Finds the min distance at each row, and removes rows below the cutoff. + r, w = torch.triu_indices(*batch_batch_dist.shape) # Remove upper diag. + batch_batch_dist[r, w] = torch.finfo(batch_batch_dist.dtype).max + batch_batch_dist = batch_batch_dist.min(-1)[0] + idx_batch_batch = batch_batch_dist > self.cutoff_distance + training_objects = training_objects[idx_batch_batch] + + # Compute all pairwise distances between the remaining batch & buffer. + batch = training_objects.last_states.tensor.float() + buffer = self.training_objects.last_states.tensor.float() + batch_dim = training_objects.last_states.batch_shape[0] + buffer_dim = self.training_objects.last_states.batch_shape[0] + batch_buffer_dist = ( + torch.cdist( + batch.view(batch_dim, -1).unsqueeze(0), + buffer.view(buffer_dim, -1).unsqueeze(0), + p=self.p_norm_distance, + ) + .squeeze(0) + .min(-1)[0] # Min calculated over rows - the batch elements. ) - .squeeze(0) - .min(-1)[0] # Min calculated over rows, i.e., over the batch elements. - ) - # Filter the batch for diverse final_states w.r.t the buffer. - idx_batch_buffer = batch_buffer_dist > self.cutoff_distance - training_objects = training_objects[idx_batch_buffer] + # Filter the batch for diverse final_states w.r.t the buffer. + idx_batch_buffer = batch_buffer_dist > self.cutoff_distance + training_objects = training_objects[idx_batch_buffer] # If any training object remain after filtering, add them. if len(training_objects): From 54720553f58416979e66719d928634af3d63fd00 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 2 Apr 2024 17:41:12 -0400 Subject: [PATCH 10/10] added comment --- src/gfn/containers/transitions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/gfn/containers/transitions.py b/src/gfn/containers/transitions.py index 8646e238..cbc214f6 100644 --- a/src/gfn/containers/transitions.py +++ b/src/gfn/containers/transitions.py @@ -214,6 +214,7 @@ def extend(self, other: Transitions) -> None: self._log_rewards = torch.cat( (self._log_rewards, other._log_rewards), dim=0 ) + # Will not be None if object is initialized as empty. else: self._log_rewards = None self.log_probs = torch.cat((self.log_probs, other.log_probs), dim=0)
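
Usage sketch (illustrative only; `args`, `env`, `objects_type`, and `training_objects` are assumed to be the variables already defined in tutorials/examples/train_hypergrid.py, and the trailing add() call stands in for wherever the training loop pushes sampled objects into the buffer). Patches 02 and 08 select between the two buffers from the new --replay_buffer_prioritized flag roughly as follows:

    from gfn.containers import ReplayBuffer, PrioritizedReplayBuffer

    replay_buffer = None
    if args.replay_buffer_size > 0:
        if args.replay_buffer_prioritized:
            replay_buffer = PrioritizedReplayBuffer(
                env,
                objects_type=objects_type,
                capacity=args.replay_buffer_size,
                p_norm_distance=1,  # L1 norm for the diversity estimate.
                cutoff_distance=0,  # Negative values disable diversity filtering.
            )
        else:
            replay_buffer = ReplayBuffer(
                env, objects_type=objects_type, capacity=args.replay_buffer_size
            )

    # Somewhere in the training loop: push freshly sampled objects into the buffer.
    # Once full, the prioritized variant keeps only additions whose log-reward is at
    # least the buffer minimum and whose last states are far enough (p-norm distance
    # above cutoff_distance) from both the rest of the batch and the buffer contents.
    if replay_buffer is not None:
        replay_buffer.add(training_objects)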
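
The in-batch diversity filter (patch 03, reworked in patches 07 and 09) masks the upper triangle of the pairwise distance matrix, diagonal included, before taking a row-wise min, so each row ends up holding the distance to its nearest other batch element. A small self-contained check of that idiom, using made-up 2-D last states rather than anything from the library:

    import torch

    # Three fake "last states": two identical points and one far-away point.
    batch = torch.tensor([[0.0, 0.0], [0.0, 0.0], [3.0, 4.0]])

    dists = torch.cdist(batch.unsqueeze(0), batch.unsqueeze(0), p=1).squeeze(0)
    r, w = torch.triu_indices(*dists.shape)     # Upper triangle, diagonal included.
    dists[r, w] = torch.finfo(dists.dtype).max  # Mask it with a huge sentinel.
    nearest = dists.min(-1)[0]                  # Row-wise nearest *other* element.

    # Row 0 has no lower-triangle entries, so it keeps the sentinel and always
    # survives; row 1 duplicates row 0 (distance 0) and is filtered out; row 2 is
    # 7 away (L1) from both others and survives any small cutoff.
    keep = nearest > 0.5                        # e.g. cutoff_distance = 0.5
    print(nearest, keep)                        # keep == tensor([True, False, True])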