diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py
index 0ea3e913..5a3cf8dd 100644
--- a/tutorials/examples/train_box.py
+++ b/tutorials/examples/train_box.py
@@ -6,7 +6,6 @@
 python train_box.py --delta {0.1, 0.25} --tied {--uniform_pb} --loss {TB, DB}
 """
-
 from argparse import ArgumentParser
 
 import numpy as np
 
@@ -233,9 +232,7 @@ def main(args):  # noqa: C901
             print(f"current optimizer LR: {optimizer.param_groups[0]['lr']}")
 
         trajectories = gflownet.sample_trajectories(
-            env,
-            sample_off_policy=False,
-            n_samples=args.batch_size
+            env, sample_off_policy=False, n_samples=args.batch_size
         )
 
         training_samples = gflownet.to_training_samples(trajectories)
diff --git a/tutorials/examples/train_discreteebm.py b/tutorials/examples/train_discreteebm.py
index 68b1ba9f..33aa1cc8 100644
--- a/tutorials/examples/train_discreteebm.py
+++ b/tutorials/examples/train_discreteebm.py
@@ -10,7 +10,6 @@
 [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782)
 python train_hypergrid.py --ndim {2, 4} --height 12 --R0 {1e-3, 1e-4} --tied --loss {TB, DB, SubTB}
 """
-
 from argparse import ArgumentParser
 
 import torch
@@ -20,10 +19,9 @@
 from gfn.gflownet import FMGFlowNet
 from gfn.gym import DiscreteEBM
 from gfn.modules import DiscretePolicyEstimator
-from gfn.utils.common import validate
-from gfn.utils.modules import NeuralNet, Tabular
-
 from gfn.utils.common import set_seed
+from gfn.utils.modules import NeuralNet, Tabular
+from gfn.utils.training import validate
 
 DEFAULT_SEED = 4444
 
@@ -72,9 +70,7 @@ def main(args):  # noqa: C901
     validation_info = {"l1_dist": float("inf")}
     for iteration in trange(n_iterations):
         trajectories = gflownet.sample_trajectories(
-            env,
-            off_policy=False,
-            n_samples=args.batch_size
+            env, off_policy=False, n_samples=args.batch_size
         )
 
         training_samples = gflownet.to_training_samples(trajectories)
diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py
index 113df50f..e3301cdd 100644
--- a/tutorials/examples/train_hypergrid.py
+++ b/tutorials/examples/train_hypergrid.py
@@ -10,7 +10,6 @@
 [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782)
 python train_hypergrid.py --ndim {2, 4} --height 12 --R0 {1e-3, 1e-4} --tied --loss {TB, DB, SubTB}
 """
-
 from argparse import ArgumentParser
 
 import torch
@@ -28,10 +27,9 @@
 )
 from gfn.gym import HyperGrid
 from gfn.modules import DiscretePolicyEstimator, ScalarEstimator
-from gfn.utils.common import validate
-from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular
-
 from gfn.utils.common import set_seed
+from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular
+from gfn.utils.training import validate
 
 DEFAULT_SEED = 4444
 
@@ -225,7 +223,9 @@ def main(args):  # noqa: C901
     n_iterations = args.n_trajectories // args.batch_size
     validation_info = {"l1_dist": float("inf")}
     for iteration in trange(n_iterations):
-        trajectories = gflownet.sample_trajectories(env, n_samples=args.batch_size, sample_off_policy=off_policy_sampling)
+        trajectories = gflownet.sample_trajectories(
+            env, n_samples=args.batch_size, sample_off_policy=off_policy_sampling
+        )
         training_samples = gflownet.to_training_samples(trajectories)
         if replay_buffer is not None:
             with torch.no_grad():
diff --git a/tutorials/examples/train_line.py b/tutorials/examples/train_line.py
index 3d0042e5..645a6f06 100644
--- a/tutorials/examples/train_line.py
+++ b/tutorials/examples/train_line.py
@@ -1,4 +1,3 @@
-import random
 from typing import ClassVar, Literal, Tuple
 
 import matplotlib.pyplot as plt
@@ -15,7 +14,6 @@
 from gfn.modules import GFNModule
 from gfn.states import States
 from gfn.utils import NeuralNet
-
 from gfn.utils.common import set_seed
 
 
@@ -215,7 +213,9 @@ def log_prob(self, sampled_actions):
         actions_to_eval[~exit_idx] = sampled_actions[~exit_idx]
 
         if sum(~exit_idx) > 0:
-            logprobs[~exit_idx] = self.dist.log_prob(actions_to_eval)[~exit_idx].unsqueeze(-1)
+            logprobs[~exit_idx] = self.dist.log_prob(actions_to_eval)[
+                ~exit_idx
+            ].unsqueeze(-1)
 
         return logprobs.squeeze(-1)
 
@@ -289,6 +289,7 @@ def to_probability_distribution(
             n_steps=self.n_steps_per_trajectory,
         )
 
+
 def train(
     gflownet,
     env,
@@ -322,7 +323,6 @@ def train(
     scale_schedule = np.linspace(exploration_var_starting_val, 0, n_iterations)
 
     for iteration in tbar:
-
         optimizer.zero_grad()
         # Off Policy Sampling.
         trajectories = gflownet.sample_trajectories(
@@ -361,7 +361,6 @@
 
 
 if __name__ == "__main__":
-
     environment = Line(
         mus=[2, 5],
         sigmas=[0.5, 0.5],
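
Note on the API surface exercised above, beyond the black/isort reformatting: `validate` now lives in `gfn.utils.training` rather than `gfn.utils.common`, and `sample_trajectories` takes an explicit off-policy flag (spelled `off_policy` in train_discreteebm.py but `sample_off_policy` in train_box.py and train_hypergrid.py). Below is a minimal sketch of the resulting training-loop pattern; it is not part of the diff. `env`, `gflownet`, `optimizer`, `batch_size`, and `n_iterations` stand in for whatever the surrounding script constructs, and the `validate` call assumes its default keyword arguments.

    from gfn.utils.common import set_seed
    from gfn.utils.training import validate  # new location after this change

    set_seed(4444)  # DEFAULT_SEED in these scripts

    for iteration in range(n_iterations):
        optimizer.zero_grad()
        # Off-policy sampling is now opted into explicitly on every call.
        trajectories = gflownet.sample_trajectories(
            env, sample_off_policy=False, n_samples=batch_size
        )
        training_samples = gflownet.to_training_samples(trajectories)
        # Loss/backward/step as in the surrounding scripts (unchanged by this diff).
        loss = gflownet.loss(env, training_samples)
        loss.backward()
        optimizer.step()

    # Periodic evaluation; extra kwargs (sample count, visited states) omitted here.
    validation_info = validate(env, gflownet)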