Skip to content

Commit

Permalink
added mps seeding
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Nov 15, 2024
1 parent 7c1a3da commit 3041fb2
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
3 changes: 3 additions & 0 deletions src/gfn/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ def set_seed(seed: int, performance_mode: bool = False) -> None:
np.random.seed(seed)
torch.manual_seed(seed)

if torch.backends.mps.is_available():
torch.mps.manual_seed(seed)

# These are only set when we care about reproducibility over performance.
if not performance_mode:
torch.backends.cudnn.deterministic = True
Expand Down
8 changes: 4 additions & 4 deletions tutorials/examples/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,17 @@ class BoxArgs(CommonArgs):
@pytest.mark.parametrize("ndim", [2, 4])
@pytest.mark.parametrize("height", [8, 16])
def test_hypergrid(ndim: int, height: int):
n_trajectories = 32000 if ndim == 2 else 16000
n_trajectories = 64000 # if ndim == 2 else 16000
args = HypergridArgs(ndim=ndim, height=height, n_trajectories=n_trajectories)
final_l1_dist = train_hypergrid_main(args)
if ndim == 2 and height == 8:
assert np.isclose(final_l1_dist, 8.78e-4, atol=1e-3)
elif ndim == 2 and height == 16:
assert np.isclose(final_l1_dist, 4.56e-4, atol=1e-4)
assert np.isclose(final_l1_dist, 2.62e-4, atol=1e-3)
elif ndim == 4 and height == 8:
assert np.isclose(final_l1_dist, 1.6e-4, atol=1e-4)
assert np.isclose(final_l1_dist, 1.6e-4, atol=1e-3)
elif ndim == 4 and height == 16:
assert np.isclose(final_l1_dist, 2.45e-5, atol=1e-5)
assert np.isclose(final_l1_dist, 6.89e-6, atol=1e-5)


@pytest.mark.parametrize("ndim", [2, 4])
Expand Down

0 comments on commit 3041fb2

Please sign in to comment.