From 421a16775f7d01a38b8a94a71da525b4b28850aa Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 18 Jul 2024 17:45:20 -0400 Subject: [PATCH] fixed hypergrid tests --- testing/test_environments.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/testing/test_environments.py b/testing/test_environments.py index 5dbd4cc6..6b0fede8 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -298,21 +298,23 @@ def test_get_grid(): HEIGHT = 8 NDIM = 2 - env = HyperGrid(height=HEIGHT, ndim=NDIM) - grid = env.build_grid() + env = HyperGrid( + height=HEIGHT, ndim=NDIM, calculate_all_states=True, calculate_partition=True + ) + all_states = env.all_states - assert grid.batch_shape == (HEIGHT, HEIGHT) - assert grid.state_shape == (NDIM,) + assert all_states.batch_shape == (HEIGHT**2,) + assert all_states.state_shape == (NDIM,) - rewards = env.reward(grid) - assert tuple(rewards.shape) == grid.batch_shape + rewards = env.reward(all_states) + assert tuple(rewards.shape) == all_states.batch_shape # All rewards are positive. assert torch.sum(rewards > 0) == HEIGHT**2 # log(Z) should equal the environment log_partition. Z = rewards.sum() - assert Z.log().item() == env.log_partition + assert np.isclose(Z.log().item(), env.log_partition) # State indices of the grid are ordered from 0:HEIGHT**2. - assert (env.get_states_indices(grid).ravel() == torch.arange(HEIGHT**2)).all() + assert (env.get_states_indices(all_states).ravel() == torch.arange(HEIGHT**2)).all()