diff --git a/tutorials/examples/train_hypergrid_multinode.py b/tutorials/examples/train_hypergrid_multinode.py index b6eac423..98e1e4ff 100644 --- a/tutorials/examples/train_hypergrid_multinode.py +++ b/tutorials/examples/train_hypergrid_multinode.py @@ -373,6 +373,7 @@ def validate_hypergrid( ) # Add the mode counting metric. + # TODO: This is not the same as what is done in `train_hypergrid`. states, scale = visited_terminating_states.tensor, env.scale_factor normalized_states = ((states * scale) - (scale / 2) * (env.height - 1)).abs()