diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py
index 2041c7ca..efef2645 100644
--- a/tutorials/examples/train_hypergrid.py
+++ b/tutorials/examples/train_hypergrid.py
@@ -17,7 +17,7 @@
 import wandb
 from tqdm import tqdm, trange
 
-from gfn.containers import ReplayBuffer
+from gfn.containers import ReplayBuffer, PrioritizedReplayBuffer
 from gfn.gflownet import (
     DBGFlowNet,
     FMGFlowNet,
@@ -185,12 +185,17 @@ def main(args):  # noqa: C901
             objects_type = "states"
         else:
             raise NotImplementedError(f"Unknown loss: {args.loss}")
-        replay_buffer = ReplayBuffer(
-            env, objects_type=objects_type, capacity=args.replay_buffer_size
-        )
 
-    # 3. Create the optimizer
+        if args.replay_buffer_prioritized:
+            replay_buffer = PrioritizedReplayBuffer(
+                env, objects_type=objects_type, capacity=args.replay_buffer_size
+            )
+        else:
+            replay_buffer = ReplayBuffer(
+                env, objects_type=objects_type, capacity=args.replay_buffer_size
+            )
 
+    # 3. Create the optimizer
     # Policy parameters have their own LR.
     params = [
         {
@@ -292,6 +297,11 @@ def main(args):  # noqa: C901
         default=0,
         help="If zero, no replay buffer is used. Otherwise, the replay buffer is used.",
     )
+    parser.add_argument(
+        "--replay_buffer_prioritized",
+        action="store_true",
+        help="If set and replay_buffer_size > 0, use a prioritized replay buffer.",
+    )
     parser.add_argument(
         "--loss",
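
Usage note: with this patch applied, prioritization is opt-in and only takes effect when a replay buffer is enabled (replay_buffer_size > 0). A sketch of a run exercising the new code path; the loss name and buffer size below are illustrative values, not taken from the patch:

    python tutorials/examples/train_hypergrid.py --loss DB --replay_buffer_size 1000 --replay_buffer_prioritized

Omitting --replay_buffer_prioritized preserves the previous behavior: the flag defaults to False under action="store_true", so the plain ReplayBuffer is constructed as before.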