Skip to content

Commit

Permalink
fixed hypergrid tests
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Jul 18, 2024
1 parent f7e5f92 commit 421a167
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions testing/test_environments.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,21 +298,23 @@ def test_get_grid():
HEIGHT = 8
NDIM = 2

env = HyperGrid(height=HEIGHT, ndim=NDIM)
grid = env.build_grid()
env = HyperGrid(
height=HEIGHT, ndim=NDIM, calculate_all_states=True, calculate_partition=True
)
all_states = env.all_states

assert grid.batch_shape == (HEIGHT, HEIGHT)
assert grid.state_shape == (NDIM,)
assert all_states.batch_shape == (HEIGHT**2,)
assert all_states.state_shape == (NDIM,)

rewards = env.reward(grid)
assert tuple(rewards.shape) == grid.batch_shape
rewards = env.reward(all_states)
assert tuple(rewards.shape) == all_states.batch_shape

# All rewards are positive.
assert torch.sum(rewards > 0) == HEIGHT**2

# log(Z) should equal the environment log_partition.
Z = rewards.sum()
assert Z.log().item() == env.log_partition
assert np.isclose(Z.log().item(), env.log_partition)

# State indices of the grid are ordered from 0:HEIGHT**2.
assert (env.get_states_indices(grid).ravel() == torch.arange(HEIGHT**2)).all()
assert (env.get_states_indices(all_states).ravel() == torch.arange(HEIGHT**2)).all()

0 comments on commit 421a167

Please sign in to comment.