Skip to content

Commit

Permalink
NOT WORKING: this commit contains trajectories_states_b which is the …
Browse files Browse the repository at this point in the history
…proposed new method for stacking a list of states into a trajectory, but as the assert statements show, the tensor is correct, but the forward_masks are not
  • Loading branch information
josephdviviano committed Feb 22, 2024
1 parent 2afb00e commit be2fee1
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions src/gfn/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from gfn.containers import Trajectories
from gfn.env import Env
from gfn.modules import GFNModule
from gfn.states import States
from gfn.states import States, stack_states


class Sampler:
Expand Down Expand Up @@ -140,6 +140,8 @@ def sample_trajectories(
else states.is_sink_state
)

trajectories_states_b: List[States] = [states]

trajectories_states: List[TT["n_trajectories", "state_shape", torch.float]] = [
states.tensor
]
Expand Down Expand Up @@ -220,9 +222,18 @@ def sample_trajectories(
dones = dones | new_dones

trajectories_states += [states.tensor]
trajectories_states_b += [states]

# New Method
trajectories_states_b = stack_states(trajectories_states_b)

# Old Method
trajectories_states = env.states_from_tensor(
torch.stack(trajectories_states, dim=0))

assert (trajectories_states_b.tensor == trajectories_states.tensor).sum() == trajectories_states.tensor.numel()
assert (trajectories_states_b.forward_masks == trajectories_states.forward_masks).sum() == trajectories_states.forward_masks.numel()

trajectories_states = torch.stack(trajectories_states, dim=0)
trajectories_states = env.states_from_tensor(trajectories_states)
trajectories_actions = env.Actions.stack(trajectories_actions)
trajectories_logprobs = torch.stack(trajectories_logprobs, dim=0)

Expand Down

0 comments on commit be2fee1

Please sign in to comment.