Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Feb 24, 2024
1 parent 77e7e1b commit 4e364d3
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions src/gfn/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,6 @@ def set_nonexit_action_masks(self, cond, allow_exit: bool):
exit_idx = torch.ones(self.batch_shape + (1,)).to(cond.device)
self.forward_masks[torch.cat([cond, exit_idx], dim=-1).bool()] = False


def set_exit_masks(self, batch_idx):
"""Sets forward masks such that the only allowable next action is to exit.
Expand Down Expand Up @@ -458,14 +457,22 @@ def stack_states(states: List[States]):
stacked_states = state_example.from_batch_shape((0, 0)) # Empty.
stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0)
if state_example._log_rewards:
stacked_states._log_rewards = torch.stack([s._log_rewards for s in states], dim=0)
stacked_states._log_rewards = torch.stack(
[s._log_rewards for s in states], dim=0
)

# We are dealing with a list of DiscreteStates instances.
if hasattr(state_example, "forward_masks"):
stacked_states.forward_masks = torch.stack([s.forward_masks for s in states], dim=0)
stacked_states.backward_masks = torch.stack([s.backward_masks for s in states], dim=0)
stacked_states.forward_masks = torch.stack(
[s.forward_masks for s in states], dim=0
)
stacked_states.backward_masks = torch.stack(
[s.backward_masks for s in states], dim=0
)

# Adds the trajectory dimension.
stacked_states.batch_shape = (stacked_states.tensor.shape[0],) + state_example.batch_shape
stacked_states.batch_shape = (
stacked_states.tensor.shape[0],
) + state_example.batch_shape

return stacked_states

0 comments on commit 4e364d3

Please sign in to comment.