Skip to content

Commit

Permalink
added required flags
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Jul 18, 2024
1 parent be43e32 commit fe9a248
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions tutorials/examples/train_hypergrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,20 @@ def main(args): # noqa: C901
use_wandb = len(args.wandb_project) > 0
if use_wandb:
import wandb

wandb.init(project=args.wandb_project)
wandb.config.update(args)

# 1. Create the environment
env = HyperGrid(
args.ndim, args.height, args.R0, args.R1, args.R2, device_str=device_str
args.ndim,
args.height,
args.R0,
args.R1,
args.R2,
device_str=device_str,
calculate_partition=args.calculate_partition,
calculate_all_states=args.calculate_all_states,
)

# 2. Create the gflownets.
Expand Down Expand Up @@ -255,7 +263,9 @@ def main(args): # noqa: C901
to_log = {"loss": loss.item(), "states_visited": states_visited}
if use_wandb:
wandb.log(to_log, step=iteration)
if (iteration % args.validation_interval == 0) or (iteration == n_iterations - 1):
if (iteration % args.validation_interval == 0) or (
iteration == n_iterations - 1
):
validation_info, discovered_modes = validate_hypergrid(
env,
gflownet,
Expand All @@ -271,7 +281,6 @@ def main(args): # noqa: C901
try:
return validation_info["l1_dist"]
except KeyError:
print(validation_info.keys())
return validation_info["n_modes_found"]


Expand All @@ -291,7 +300,7 @@ def validate_hypergrid(

# # Add the mode counting metric.
states, scale = visited_terminating_states.tensor, env.scale_factor
mode_reward_threshold = 1. # Assumes height >= 5. TODO - verify.
mode_reward_threshold = 1.0 # Assumes height >= 5. TODO - verify.

# # Modes will have a reward greater than 1.
modes = states[env.reward(states) >= mode_reward_threshold]
Expand Down

0 comments on commit fe9a248

Please sign in to comment.