From 8481673495314ffea4462822c80479abca134f43 Mon Sep 17 00:00:00 2001 From: hyeok9855 Date: Wed, 27 Nov 2024 03:12:06 +0900 Subject: [PATCH] fix a minor error in the Trajectory.__repr__ --- src/gfn/containers/trajectories.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index 5feb665a..6363c925 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -122,7 +122,7 @@ def __repr__(self) -> str: for traj in states[:10]: one_traj_repr = [] for step in traj: - one_traj_repr.append(str(step.numpy())) + one_traj_repr.append(str(step.cpu().numpy())) if step.equal(self.env.s0 if self.is_backward else self.env.sf): break trajectories_representation += "-> ".join(one_traj_repr) + "\n" @@ -130,7 +130,7 @@ def __repr__(self) -> str: f"Trajectories(n_trajectories={self.n_trajectories}, max_length={self.max_length}, First 10 trajectories:" + f"states=\n{trajectories_representation}" # + f"actions=\n{self.actions.tensor.squeeze().transpose(0, 1)[:10].numpy()}, " - + f"when_is_done={self.when_is_done[:10].numpy()})" + + f"when_is_done={self.when_is_done[:10].cpu().numpy()})" ) @property