From c57a708f063b0cd66d134c190d0a0dd98f6de609 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 18 Jun 2024 16:14:18 -0400 Subject: [PATCH] added mode counting --- tutorials/examples/train_hypergrid.py | 36 ++++++++++++++++++- .../examples/train_hypergrid_multinode.py | 7 ++-- 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 702ba4fc..08918c18 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -226,6 +226,8 @@ def main(args): # noqa: C901 states_visited = 0 n_iterations = args.n_trajectories // args.batch_size validation_info = {"l1_dist": float("inf")} + discovered_modes = set() + for iteration in trange(n_iterations): trajectories = gflownet.sample_trajectories( env, @@ -254,11 +256,12 @@ def main(args): # noqa: C901 if use_wandb: wandb.log(to_log, step=iteration) if (iteration % args.validation_interval == 0) or (iteration == n_iterations - 1): - validation_info = validate( + validation_info, discovered_modes = validate_hypergrid( env, gflownet, args.validation_samples, visited_terminating_states, + discovered_modes, ) if use_wandb: wandb.log(validation_info, step=iteration) @@ -268,6 +271,37 @@ def main(args): # noqa: C901 return validation_info["l1_dist"] +def validate_hypergrid( + env, + gflownet, + n_validation_samples, + visited_terminating_states, + discovered_modes, +): + validation_info = validate( # Standard validation shared across envs. + env, + gflownet, + n_validation_samples, + visited_terminating_states, + ) + + # Add the mode counting metric. + states, scale = visited_terminating_states.tensor, env.scale_factor + + normalized_states = ((states * scale) - (scale / 2) * (env.height - 1)).abs() + + modes = torch.all( + (normalized_states > (0.3 * scale) * (env.height - 1)) + & (normalized_states <= (0.4 * scale) * (env.height - 1)), + dim=-1, + ) + modes_found = set([tuple(s.tolist()) for s in states[modes.bool()]]) + discovered_modes.update(modes_found) + validation_info["n_modes_found"] = len(discovered_modes) + + return validation_info, discovered_modes + + if __name__ == "__main__": parser = ArgumentParser() diff --git a/tutorials/examples/train_hypergrid_multinode.py b/tutorials/examples/train_hypergrid_multinode.py index 05e50845..379afa7f 100644 --- a/tutorials/examples/train_hypergrid_multinode.py +++ b/tutorials/examples/train_hypergrid_multinode.py @@ -37,7 +37,7 @@ from gfn.modules import DiscretePolicyEstimator, ScalarEstimator from gfn.utils.common import set_seed from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular -from gfn.utils.training import validate +from .train_hypergrid import validate_hypergrid DEFAULT_SEED = 4444 @@ -292,6 +292,8 @@ def main(args): # noqa: C901 print ("n_iterations = ", n_iterations) print ("my_batch_size = ", my_batch_size) time_start = time.time() + discovered_modes = set() + for iteration in trange(n_iterations): sample_start = time.time() trajectories = gflownet.sample_trajectories( @@ -333,11 +335,12 @@ def main(args): # noqa: C901 if use_wandb: wandb.log(to_log, step=iteration) if (iteration % args.validation_interval == 0) or (iteration == n_iterations - 1): - validation_info = validate( + validation_info, discovered_modes = validate_hypergrid( env, gflownet, args.validation_samples, visited_terminating_states, + discovered_modes, ) if use_wandb: wandb.log(validation_info, step=iteration)