Skip to content

Commit

Permalink
replaced inappropriate env. with self.
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Nov 13, 2024
1 parent 406bb3d commit 2cf109c
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tutorials/notebooks/intro_gfn_smiley.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1708,7 +1708,7 @@
" # Then, we remove any done action, and also the exit action.\n",
" states.set_nonexit_action_masks(states.tensor == 1, allow_exit=False)\n",
"\n",
" if env.mask_invalid_actions:\n",
" if self.mask_invalid_actions:\n",
" # Now we remove invalid actions. Here we are enforcing that\n",
" # only one left eyebrow, one right eyebrow, and one smile can be\n",
" # selected. 0 = not allowed.\n",
Expand Down Expand Up @@ -1746,7 +1746,7 @@
" + 1 if the face is frowny. :(\n",
" + 2 if the face is smiley. :)\n",
" \"\"\"\n",
" if not env.mask_invalid_actions:\n",
" if not self.mask_invalid_actions:\n",
" # Tensor organization is [left_eb *2, right_eb * 2, mouth * 2]\n",
" valid = torch.zeros(states.batch_shape + (3,))\n",
" valid[..., 0] = states.tensor[..., :2].sum(-1) == 1 # One left eyebrow.\n",
Expand All @@ -1759,7 +1759,7 @@
" rewards[states.tensor[..., 4] == 1] = torch.tensor([2]) # Smiles.\n",
" rewards[states.tensor[..., 5] == 1] = torch.tensor([1]) # Frowns.\n",
"\n",
" if not env.mask_invalid_actions:\n",
" if not self.mask_invalid_actions:\n",
" rewards = rewards * valid # This will remove any double mouths.\n",
"\n",
" return rewards.squeeze()\n",
Expand Down

0 comments on commit 2cf109c

Please sign in to comment.