Skip to content

Commit

Permalink
can use either standard or prioritized replay buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Mar 30, 2024
1 parent e087f41 commit 61f7fd2
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions tutorials/examples/train_hypergrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import wandb
from tqdm import tqdm, trange

from gfn.containers import ReplayBuffer
from gfn.containers import ReplayBuffer, PrioritizedReplayBuffer
from gfn.gflownet import (
DBGFlowNet,
FMGFlowNet,
Expand Down Expand Up @@ -185,12 +185,17 @@ def main(args): # noqa: C901
objects_type = "states"
else:
raise NotImplementedError(f"Unknown loss: {args.loss}")
replay_buffer = ReplayBuffer(
env, objects_type=objects_type, capacity=args.replay_buffer_size
)

# 3. Create the optimizer
if args.replay_buffer_prioritized:
replay_buffer = PrioritizedReplayBuffer(
env, objects_type=objects_type, capacity=args.replay_buffer_size
)
else:
replay_buffer = ReplayBuffer(
env, objects_type=objects_type, capacity=args.replay_buffer_size
)

# 3. Create the optimizer
# Policy parameters have their own LR.
params = [
{
Expand Down Expand Up @@ -292,6 +297,11 @@ def main(args): # noqa: C901
default=0,
help="If zero, no replay buffer is used. Otherwise, the replay buffer is used.",
)
parser.add_argument(
"--replay_buffer_prioritized",
action="store_true",
help="If set and replay_buffer_size > 0, use a prioritized replay buffer.",
)

parser.add_argument(
"--loss",
Expand Down

0 comments on commit 61f7fd2

Please sign in to comment.