Skip to content

Commit

Permalink
no longer using relative import (but using yucky code duplication)
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Jun 18, 2024
1 parent c2819d5 commit fa68095
Showing 1 changed file with 32 additions and 1 deletion.
33 changes: 32 additions & 1 deletion tutorials/examples/train_hypergrid_multinode.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from gfn.modules import DiscretePolicyEstimator, ScalarEstimator
from gfn.utils.common import set_seed
from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular
from .train_hypergrid import validate_hypergrid
from gfn.utils.training import validate

DEFAULT_SEED = 4444

Expand Down Expand Up @@ -358,6 +358,37 @@ def main(args): # noqa: C901
return validation_info["l1_dist"]


def validate_hypergrid(
env,
gflownet,
n_validation_samples,
visited_terminating_states,
discovered_modes,
):
validation_info = validate( # Standard validation shared across envs.
env,
gflownet,
n_validation_samples,
visited_terminating_states,
)

# Add the mode counting metric.
states, scale = visited_terminating_states.tensor, env.scale_factor

normalized_states = ((states * scale) - (scale / 2) * (env.height - 1)).abs()

modes = torch.all(
(normalized_states > (0.3 * scale) * (env.height - 1))
& (normalized_states <= (0.4 * scale) * (env.height - 1)),
dim=-1,
)
modes_found = set([tuple(s.tolist()) for s in states[modes.bool()]])
discovered_modes.update(modes_found)
validation_info["n_modes_found"] = len(discovered_modes)

return validation_info, discovered_modes


if __name__ == "__main__":
parser = ArgumentParser()

Expand Down

0 comments on commit fa68095

Please sign in to comment.