Skip to content

Commit

Permalink
log_rewards are stored properly in the case that the external traject…
Browse files Browse the repository at this point in the history
…ory contains log_rewards and the internal trajectory is None (this can happen with an empty-initialized trajectory)
  • Loading branch information
josephdviviano committed Mar 30, 2024
1 parent a50af8e commit e087f41
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions src/gfn/containers/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,22 +232,34 @@ def extend(self, other: Trajectories) -> None:
self.states.extend(other.states)
self.when_is_done = torch.cat((self.when_is_done, other.when_is_done), dim=0)

# For log_probs, we first need to make the first dimensions of self.log_probs and other.log_probs equal
# (i.e. the number of steps in the trajectories), and then concatenate them
# For log_probs, we first need to make the first dimensions of self.log_probs
# and other.log_probs equal (i.e. the number of steps in the trajectories), and
# then concatenate them.
new_max_length = max(self.log_probs.shape[0], other.log_probs.shape[0])
self.log_probs = self.extend_log_probs(self.log_probs, new_max_length)
other.log_probs = self.extend_log_probs(other.log_probs, new_max_length)

self.log_probs = torch.cat((self.log_probs, other.log_probs), dim=1)

# Concatenate log_rewards of the trajectories.
if self._log_rewards is not None and other._log_rewards is not None:
self._log_rewards = torch.cat(
(self._log_rewards, other._log_rewards),
dim=0,
)
# If the trajectories object does not yet have `log_rewards` assigned but the
# external trajectory has log_rewards, simply assign them over.
elif self._log_rewards is None and other._log_rewards is not None:
self._log_rewards = other._log_rewards
else:
self._log_rewards = None

# Ensure log_probs/rewards are the correct dimensions. TODO: Remove?
if self.log_probs.numel() > 0:
assert len(self.log_probs) == self.states.batch_shape[-1]

if self.log_rewards is not None:
assert len(self.log_rewards) == self.states.batch_shape[-1]

# Either set, or append, estimator outputs if they exist in the submitted
# trajectory.
if self.estimator_outputs is None and isinstance(
Expand Down

0 comments on commit e087f41

Please sign in to comment.