Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

small fixes to seeding and hypergrid tests, docstring improvements to resolve confusion of get_terminating_states_indices purpose #217

Merged
merged 3 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/gfn/gym/discrete_ebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def get_states_indices(self, states: DiscreteStates) -> torch.Tensor:
return states_indices

def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor:
"""Returns the indices of the terminating states.
"""Get the indices of the terminating states in the canonical ordering from the submitted states.

Args:
states: DiscreteStates object representing the states.
Expand Down
5 changes: 3 additions & 2 deletions src/gfn/gym/hypergrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def update_masks(self, states: type[DiscreteStates]) -> None:
"""Update the masks based on the current states."""
# Not allowed to take any action beyond the environment height, but
# allow early termination.
# TODO: do we need to handle the conditional case here?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so. We may need another EnvBase that incorporates conditions.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ConditionalEnvs are interesting. I should actually do this ASAP so we have a working conditional GFN example.

states.set_nonexit_action_masks(
states.tensor == self.height - 1,
allow_exit=True,
Expand Down Expand Up @@ -174,9 +175,9 @@ def get_states_indices(self, states: DiscreteStates) -> torch.Tensor:
return indices

def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor:
"""Get the indices of the terminating states in the canonical ordering.
"""Get the indices of the terminating states in the canonical ordering from the submitted states.

Returns the indices of the terminating states in the canonical ordering as a tensor of shape `batch_shape`.
Canonical ordering is returned as a tensor of shape `batch_shape`.
"""
return self.get_states_indices(states)

Expand Down
3 changes: 3 additions & 0 deletions src/gfn/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ def set_seed(seed: int, performance_mode: bool = False) -> None:
np.random.seed(seed)
torch.manual_seed(seed)

if torch.backends.mps.is_available():
torch.mps.manual_seed(seed)

# These are only set when we care about reproducibility over performance.
if not performance_mode:
torch.backends.cudnn.deterministic = True
Expand Down
8 changes: 4 additions & 4 deletions tutorials/examples/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,17 @@ class BoxArgs(CommonArgs):
@pytest.mark.parametrize("ndim", [2, 4])
@pytest.mark.parametrize("height", [8, 16])
def test_hypergrid(ndim: int, height: int):
n_trajectories = 32000 if ndim == 2 else 16000
n_trajectories = 64000 # if ndim == 2 else 16000
args = HypergridArgs(ndim=ndim, height=height, n_trajectories=n_trajectories)
final_l1_dist = train_hypergrid_main(args)
if ndim == 2 and height == 8:
assert np.isclose(final_l1_dist, 8.78e-4, atol=1e-3)
elif ndim == 2 and height == 16:
assert np.isclose(final_l1_dist, 4.56e-4, atol=1e-4)
assert np.isclose(final_l1_dist, 2.62e-4, atol=1e-3)
elif ndim == 4 and height == 8:
assert np.isclose(final_l1_dist, 1.6e-4, atol=1e-4)
assert np.isclose(final_l1_dist, 1.6e-4, atol=1e-3)
elif ndim == 4 and height == 16:
assert np.isclose(final_l1_dist, 2.45e-5, atol=1e-5)
assert np.isclose(final_l1_dist, 6.89e-6, atol=1e-5)


@pytest.mark.parametrize("ndim", [2, 4])
Expand Down
Loading