diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index 251b5921..ff8e7852 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -16,6 +16,11 @@ from gfn.containers.transitions import Transitions +def is_tensor(t) -> bool: + """Checks whether t is a torch.Tensor instance.""" + return isinstance(t, Tensor) + + # TODO: remove env from this class? class Trajectories(Container): """Container for complete trajectories (starting in $s_0$ and ending in $s_f$). diff --git a/src/gfn/env.py b/src/gfn/env.py index 588b2553..9b045ca3 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -218,9 +218,7 @@ def _step( not_done_states = new_states[~new_sink_states_idx] not_done_actions = actions[~new_sink_states_idx] - new_not_done_states_tensor = self.maskless_step( - not_done_states, not_done_actions - ) + new_not_done_states_tensor = self.step(not_done_states, not_done_actions) new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor