Skip to content

Commit

Permalink
removed obselete states class factory
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Nov 11, 2024
1 parent a7936d9 commit 406bb3d
Showing 1 changed file with 1 addition and 44 deletions.
45 changes: 1 addition & 44 deletions tutorials/notebooks/intro_gfn_smiley.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1658,7 +1658,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": null,
"metadata": {
"id": "92RxW4V7aLk7"
},
Expand Down Expand Up @@ -1697,49 +1697,6 @@
" preprocessor=IdentityPreprocessor(output_dim=state_dim)\n",
" )\n",
"\n",
" # def make_States_class(self) -> type[DiscreteStates]:\n",
" # \"Creates a States class for this environment\"\n",
" # env = self\n",
"\n",
" # class FaceStates(DiscreteStates):\n",
" # state_shape: ClassVar[tuple[int, ...]] = (env.n_actions - 1,) # this is 6.\n",
" # s0 = env.s0 # this is 6 zeros.\n",
" # sf = env.sf # this is 6 -1's.\n",
" # n_actions = env.n_actions # this is 7. 6 features, plus 1 for exit.\n",
" # device = env.device\n",
"\n",
" # def update_masks(self) -> None:\n",
" # \"Update the masks based on the current states.\"\n",
" # # Backward masks are simply any action we've already taken.\n",
" # self.backward_masks = self.tensor != 0 # n - 1 actions.\n",
"\n",
" # # Forward masks begin as allowing any action. Allowed actions are 1.\n",
" # self.init_forward_masks(set_ones=True)\n",
"\n",
" # # Then, we remove any done action, and also the exit action.\n",
" # self.set_nonexit_masks(self.tensor == 1, allow_exit=False)\n",
"\n",
" # if env.mask_invalid_actions:\n",
" # # Now we remove invalid actions. Here we are enforcing that\n",
" # # only one left eyebrow, one right eyebrow, and one smile can be\n",
" # # selected. 0 = not allowed.\n",
" # invalid_actions = torch.ones(self.forward_masks.shape).bool()\n",
" # invalid_actions[..., 0][self.tensor[..., 1].bool()] = 0 # l_eb\n",
" # invalid_actions[..., 1][self.tensor[..., 0].bool()] = 0 # l_eb\n",
" # invalid_actions[..., 2][self.tensor[..., 3].bool()] = 0 # r_eb\n",
" # invalid_actions[..., 3][self.tensor[..., 2].bool()] = 0 # r_eb\n",
" # invalid_actions[..., 4][self.tensor[..., 5].bool()] = 0 # smile\n",
" # invalid_actions[..., 5][self.tensor[..., 4].bool()] = 0 # smile\n",
"\n",
" # self.forward_masks = (self.forward_masks * invalid_actions)\n",
"\n",
" # # Trajectories must be length 3. Any trajectory that has taken 3 actions\n",
" # # should be forced to exit.\n",
" # batch_idx = self.tensor.sum(-1) >= 3\n",
" # self.set_exit_masks(batch_idx)\n",
"\n",
" # return FaceStates\n",
"\n",
" def update_masks(self, states: type[DiscreteStates]) -> None:\n",
" \"Update the masks based on the current states.\"\n",
" # Backward masks are simply any action we've already taken.\n",
Expand Down

0 comments on commit 406bb3d

Please sign in to comment.