diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 68b052a6..f22fa06e 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -7,7 +7,7 @@ from gfn.containers import Trajectories from gfn.env import Env from gfn.modules import GFNModule -from gfn.states import States +from gfn.states import States, stack_states class Sampler: @@ -140,6 +140,8 @@ def sample_trajectories( else states.is_sink_state ) + trajectories_states_b: List[States] = [states] + trajectories_states: List[TT["n_trajectories", "state_shape", torch.float]] = [ states.tensor ] @@ -220,9 +222,18 @@ def sample_trajectories( dones = dones | new_dones trajectories_states += [states.tensor] + trajectories_states_b += [states] + + # New Method + trajectories_states_b = stack_states(trajectories_states_b) + + # Old Method + trajectories_states = env.states_from_tensor( + torch.stack(trajectories_states, dim=0)) + + assert (trajectories_states_b.tensor == trajectories_states.tensor).sum() == trajectories_states.tensor.numel() + assert (trajectories_states_b.forward_masks == trajectories_states.forward_masks).sum() == trajectories_states.forward_masks.numel() - trajectories_states = torch.stack(trajectories_states, dim=0) - trajectories_states = env.states_from_tensor(trajectories_states) trajectories_actions = env.Actions.stack(trajectories_actions) trajectories_logprobs = torch.stack(trajectories_logprobs, dim=0)