From fe9a2484c3c96a1afd49f6b15eaaa6a0edfa8aee Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 18 Jul 2024 19:40:08 -0400 Subject: [PATCH] added required flags --- tutorials/examples/train_hypergrid.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 27a6ec7d..cb9e29ad 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -42,12 +42,20 @@ def main(args): # noqa: C901 use_wandb = len(args.wandb_project) > 0 if use_wandb: import wandb + wandb.init(project=args.wandb_project) wandb.config.update(args) # 1. Create the environment env = HyperGrid( - args.ndim, args.height, args.R0, args.R1, args.R2, device_str=device_str + args.ndim, + args.height, + args.R0, + args.R1, + args.R2, + device_str=device_str, + calculate_partition=args.calculate_partition, + calculate_all_states=args.calculate_all_states, ) # 2. Create the gflownets. @@ -255,7 +263,9 @@ def main(args): # noqa: C901 to_log = {"loss": loss.item(), "states_visited": states_visited} if use_wandb: wandb.log(to_log, step=iteration) - if (iteration % args.validation_interval == 0) or (iteration == n_iterations - 1): + if (iteration % args.validation_interval == 0) or ( + iteration == n_iterations - 1 + ): validation_info, discovered_modes = validate_hypergrid( env, gflownet, @@ -271,7 +281,6 @@ def main(args): # noqa: C901 try: return validation_info["l1_dist"] except KeyError: - print(validation_info.keys()) return validation_info["n_modes_found"] @@ -291,7 +300,7 @@ def validate_hypergrid( # # Add the mode counting metric. states, scale = visited_terminating_states.tensor, env.scale_factor - mode_reward_threshold = 1. # Assumes height >= 5. TODO - verify. + mode_reward_threshold = 1.0 # Assumes height >= 5. TODO - verify. # # Modes will have a reward greater than 1. modes = states[env.reward(states) >= mode_reward_threshold]