From 7449a923d535c13d5477cf938edc35f403bc16e0 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sun, 31 Mar 2024 12:27:24 -0400 Subject: [PATCH] bugfix on assert --- src/gfn/containers/trajectories.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index 6002a330..d52274f8 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -255,10 +255,10 @@ def extend(self, other: Trajectories) -> None: # Ensure log_probs/rewards are the correct dimensions. TODO: Remove? if self.log_probs.numel() > 0: - assert len(self.log_probs) == self.states.batch_shape[-1] + assert self.log_probs.shape == self.actions.batch_shape if self.log_rewards is not None: - assert len(self.log_rewards) == self.states.batch_shape[-1] + assert len(self.log_rewards) == self.actions.batch_shape[-1] # Either set, or append, estimator outputs if they exist in the submitted # trajectory.