From 4e364d389e372810d9a91a9e2b5df237f9a64de9 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 15:34:41 -0500 Subject: [PATCH] black --- src/gfn/states.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/gfn/states.py b/src/gfn/states.py index 86eeabae..0e774b3b 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -417,7 +417,6 @@ def set_nonexit_action_masks(self, cond, allow_exit: bool): exit_idx = torch.ones(self.batch_shape + (1,)).to(cond.device) self.forward_masks[torch.cat([cond, exit_idx], dim=-1).bool()] = False - def set_exit_masks(self, batch_idx): """Sets forward masks such that the only allowable next action is to exit. @@ -458,14 +457,22 @@ def stack_states(states: List[States]): stacked_states = state_example.from_batch_shape((0, 0)) # Empty. stacked_states.tensor = torch.stack([s.tensor for s in states], dim=0) if state_example._log_rewards: - stacked_states._log_rewards = torch.stack([s._log_rewards for s in states], dim=0) + stacked_states._log_rewards = torch.stack( + [s._log_rewards for s in states], dim=0 + ) # We are dealing with a list of DiscretrStates instances. if hasattr(state_example, "forward_masks"): - stacked_states.forward_masks = torch.stack([s.forward_masks for s in states], dim=0) - stacked_states.backward_masks = torch.stack([s.backward_masks for s in states], dim=0) + stacked_states.forward_masks = torch.stack( + [s.forward_masks for s in states], dim=0 + ) + stacked_states.backward_masks = torch.stack( + [s.backward_masks for s in states], dim=0 + ) # Adds the trajectory dimension. - stacked_states.batch_shape = (stacked_states.tensor.shape[0],) + state_example.batch_shape + stacked_states.batch_shape = ( + stacked_states.tensor.shape[0], + ) + state_example.batch_shape return stacked_states