Skip to content

Commit

Permalink
bugfix on assert
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Mar 31, 2024
1 parent 75e3198 commit 7449a92
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/gfn/containers/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,10 +255,10 @@ def extend(self, other: Trajectories) -> None:

# Ensure log_probs/rewards are the correct dimensions. TODO: Remove?
if self.log_probs.numel() > 0:
assert len(self.log_probs) == self.states.batch_shape[-1]
assert self.log_probs.shape == self.actions.batch_shape

if self.log_rewards is not None:
assert len(self.log_rewards) == self.states.batch_shape[-1]
assert len(self.log_rewards) == self.actions.batch_shape[-1]

# Either set, or append, estimator outputs if they exist in the submitted
# trajectory.
Expand Down

0 comments on commit 7449a92

Please sign in to comment.