Skip to content

Commit

Permalink
minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
hyeok9855 committed Oct 28, 2024
1 parent 316a0c4 commit 8512dce
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ optimizer.add_param_group({"params": gfn.logz_parameters(), "lr": 1e-1})

# 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration
for i in (pbar := tqdm(range(1000))):
trajectories = sampler.sample_trajectories(env=env, n_trajectories=16)
trajectories = sampler.sample_trajectories(env=env, n=16)
optimizer.zero_grad()
loss = gfn.loss(env, trajectories)
loss.backward()
Expand Down Expand Up @@ -161,7 +161,7 @@ optimizer.add_param_group({"params": gfn.logF_parameters(), "lr": 1e-2})

# 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration
for i in (pbar := tqdm(range(1000))):
trajectories = sampler.sample_trajectories(env=env, n_trajectories=16)
trajectories = sampler.sample_trajectories(env=env, n=16)
optimizer.zero_grad()
loss = gfn.loss(env, trajectories)
loss.backward()
Expand Down
2 changes: 1 addition & 1 deletion tutorials/examples/train_hypergrid_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def main(args):
for i in (pbar := tqdm(range(args.n_iterations))):
trajectories = sampler.sample_trajectories(
env,
n_trajectories=args.batch_size,
n=args.batch_size,
save_logprobs=False,
save_estimator_outputs=True,
epsilon=args.epsilon,
Expand Down

0 comments on commit 8512dce

Please sign in to comment.