Commit

validation function is the same in both implementations now
josephdviviano committed Jul 17, 2024
1 parent c567a37 commit f7e5f92
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions tutorials/examples/train_hypergrid_multinode.py
@@ -376,21 +376,30 @@ def validate_hypergrid(
         visited_terminating_states,
     )

-    # Add the mode counting metric.
-    # TODO: This is not the same as what is done in `train_hypergrid`.
+    # # Add the mode counting metric.
     states, scale = visited_terminating_states.tensor, env.scale_factor
+    mode_reward_threshold = 1.  # Assumes height >= 5. TODO - verify.

-    normalized_states = ((states * scale) - (scale / 2) * (env.height - 1)).abs()
-
-    modes = torch.all(
-        (normalized_states > (0.3 * scale) * (env.height - 1))
-        & (normalized_states <= (0.4 * scale) * (env.height - 1)),
-        dim=-1,
-    )
-    modes_found = set([tuple(s.tolist()) for s in states[modes.bool()]])
+    # # Modes will have a reward greater than 1.
+    modes = states[env.reward(states) >= mode_reward_threshold]
+    modes_found = set([tuple(s.tolist()) for s in modes])
     discovered_modes.update(modes_found)
     validation_info["n_modes_found"] = len(discovered_modes)

+    # Old way of counting modes -- potentially buggy - to be removed.
+    # # Add the mode counting metric.
+    # states, scale = visited_terminating_states.tensor, env.scale_factor
+    # normalized_states = ((states * scale) - (scale / 2) * (env.height - 1)).abs()
+
+    # modes = torch.all(
+    #     (normalized_states > (0.3 * scale) * (env.height - 1))
+    #     & (normalized_states <= (0.4 * scale) * (env.height - 1)),
+    #     dim=-1,
+    # )
+    # modes_found = set([tuple(s.tolist()) for s in states[modes.bool()]])
+    # discovered_modes.update(modes_found)
+    # validation_info["n_modes_found"] = len(discovered_modes)

     return validation_info, discovered_modes


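For context on the new approach: a visited terminating state is counted as a mode whenever its reward clears a fixed threshold, and the distinct states found this way are accumulated in a set across validation calls. The sketch below illustrates that thresholding idea in isolation with a toy reward; `count_modes` and `toy_reward` are hypothetical stand-ins, not the HyperGrid reward or the tutorial's actual helpers.

import torch

def count_modes(states, rewards, discovered_modes, reward_threshold=1.0):
    """Accumulate distinct high-reward terminating states into `discovered_modes`.

    Mirrors the thresholding idea from this commit: any state whose reward is
    at least `reward_threshold` is treated as belonging to a mode.
    """
    modes = states[rewards >= reward_threshold]
    discovered_modes.update(tuple(s.tolist()) for s in modes)
    return len(discovered_modes)

# Toy usage with 2-D integer states and a made-up reward (not the HyperGrid reward).
states = torch.tensor([[0, 0], [7, 7], [7, 7], [3, 4]])
toy_reward = (states.sum(dim=-1) >= 10).float() * 2.0  # reward 2.0 only for high-coordinate states
discovered = set()
print(count_modes(states, toy_reward, discovered))  # -> 1, only (7, 7) passes the threshold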
