From 406bb3d1f8fa817cb6247c8f0c0710e8f210abf2 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Mon, 11 Nov 2024 09:16:04 -0500 Subject: [PATCH] removed obselete states class factory --- tutorials/notebooks/intro_gfn_smiley.ipynb | 45 +--------------------- 1 file changed, 1 insertion(+), 44 deletions(-) diff --git a/tutorials/notebooks/intro_gfn_smiley.ipynb b/tutorials/notebooks/intro_gfn_smiley.ipynb index 2b5a1e78..8a6ad63b 100644 --- a/tutorials/notebooks/intro_gfn_smiley.ipynb +++ b/tutorials/notebooks/intro_gfn_smiley.ipynb @@ -1658,7 +1658,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": null, "metadata": { "id": "92RxW4V7aLk7" }, @@ -1697,49 +1697,6 @@ " preprocessor=IdentityPreprocessor(output_dim=state_dim)\n", " )\n", "\n", - " # def make_States_class(self) -> type[DiscreteStates]:\n", - " # \"Creates a States class for this environment\"\n", - " # env = self\n", - "\n", - " # class FaceStates(DiscreteStates):\n", - " # state_shape: ClassVar[tuple[int, ...]] = (env.n_actions - 1,) # this is 6.\n", - " # s0 = env.s0 # this is 6 zeros.\n", - " # sf = env.sf # this is 6 -1's.\n", - " # n_actions = env.n_actions # this is 7. 6 features, plus 1 for exit.\n", - " # device = env.device\n", - "\n", - " # def update_masks(self) -> None:\n", - " # \"Update the masks based on the current states.\"\n", - " # # Backward masks are simply any action we've already taken.\n", - " # self.backward_masks = self.tensor != 0 # n - 1 actions.\n", - "\n", - " # # Forward masks begin as allowing any action. Allowed actions are 1.\n", - " # self.init_forward_masks(set_ones=True)\n", - "\n", - " # # Then, we remove any done action, and also the exit action.\n", - " # self.set_nonexit_masks(self.tensor == 1, allow_exit=False)\n", - "\n", - " # if env.mask_invalid_actions:\n", - " # # Now we remove invalid actions. Here we are enforcing that\n", - " # # only one left eyebrow, one right eyebrow, and one smile can be\n", - " # # selected. 0 = not allowed.\n", - " # invalid_actions = torch.ones(self.forward_masks.shape).bool()\n", - " # invalid_actions[..., 0][self.tensor[..., 1].bool()] = 0 # l_eb\n", - " # invalid_actions[..., 1][self.tensor[..., 0].bool()] = 0 # l_eb\n", - " # invalid_actions[..., 2][self.tensor[..., 3].bool()] = 0 # r_eb\n", - " # invalid_actions[..., 3][self.tensor[..., 2].bool()] = 0 # r_eb\n", - " # invalid_actions[..., 4][self.tensor[..., 5].bool()] = 0 # smile\n", - " # invalid_actions[..., 5][self.tensor[..., 4].bool()] = 0 # smile\n", - "\n", - " # self.forward_masks = (self.forward_masks * invalid_actions)\n", - "\n", - " # # Trajectories must be length 3. Any trajectory that has taken 3 actions\n", - " # # should be forced to exit.\n", - " # batch_idx = self.tensor.sum(-1) >= 3\n", - " # self.set_exit_masks(batch_idx)\n", - "\n", - " # return FaceStates\n", - "\n", " def update_masks(self, states: type[DiscreteStates]) -> None:\n", " \"Update the masks based on the current states.\"\n", " # Backward masks are simply any action we've already taken.\n",