From e087f414bc9e63c83815e59756cb96e262d5045a Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 30 Mar 2024 15:54:57 -0400 Subject: [PATCH] log_rewards are stored properly in the case that the external trajectory contains log_rewards and the internal trajectory is None (this can happen with empty initialized trajectory) --- src/gfn/containers/trajectories.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index 35196ec3..6002a330 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -232,22 +232,34 @@ def extend(self, other: Trajectories) -> None: self.states.extend(other.states) self.when_is_done = torch.cat((self.when_is_done, other.when_is_done), dim=0) - # For log_probs, we first need to make the first dimensions of self.log_probs and other.log_probs equal - # (i.e. the number of steps in the trajectories), and then concatenate them + # For log_probs, we first need to make the first dimensions of self.log_probs + # and other.log_probs equal (i.e. the number of steps in the trajectories), and + # then concatenate them. new_max_length = max(self.log_probs.shape[0], other.log_probs.shape[0]) self.log_probs = self.extend_log_probs(self.log_probs, new_max_length) other.log_probs = self.extend_log_probs(other.log_probs, new_max_length) - self.log_probs = torch.cat((self.log_probs, other.log_probs), dim=1) + # Concatenate log_rewards of the trajectories. if self._log_rewards is not None and other._log_rewards is not None: self._log_rewards = torch.cat( (self._log_rewards, other._log_rewards), dim=0, ) + # If the trajectories object does not yet have `log_rewards` assigned but the + # external trajectory has log_rewards, simply assign them over. 
+ elif self._log_rewards is None and other._log_rewards is not None: + self._log_rewards = other._log_rewards else: self._log_rewards = None + # Ensure log_probs/rewards are the correct dimensions. TODO: Remove? + if self.log_probs.numel() > 0: + assert len(self.log_probs) == self.states.batch_shape[-1] + + if self.log_rewards is not None: + assert len(self.log_rewards) == self.states.batch_shape[-1] + # Either set, or append, estimator outputs if they exist in the submitted # trajectory. if self.estimator_outputs is None and isinstance(