isort / black
josephdviviano committed Feb 14, 2024
1 parent cfc560c commit 2bebde2
Showing 4 changed files with 13 additions and 21 deletions.
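The diff below is the output of an isort + black pass over the tutorial scripts. As a rough, hedged illustration (not part of the commit), the same kind of pass can be reproduced from Python via the two tools' public APIs; the sketch assumes default isort and black settings, which may differ from whatever configuration the repository pins.

# Sketch only: reproduce an isort + black formatting pass programmatically.
# Assumes `pip install isort black` and default settings for both tools.
import black
import isort

source = (
    '"""Toy module mimicking the layout fixed in this commit."""\n'
    "from gfn.utils.modules import NeuralNet, Tabular\n"
    "from gfn.utils.common import set_seed\n"
    "\n"
    "\n"
    "def run(gflownet, env, batch_size):\n"
    "    return gflownet.sample_trajectories(\n"
    "        env,\n"
    "        sample_off_policy=False,\n"
    "        n_samples=batch_size\n"
    "    )\n"
)

# isort orders the imports within their group; black then re-wraps the call:
# with no magic trailing comma, all arguments are packed onto a single
# indented line, as in the hunks below.
formatted = black.format_str(isort.code(source), mode=black.Mode())
print(formatted)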
5 changes: 1 addition & 4 deletions tutorials/examples/train_box.py
@@ -6,7 +6,6 @@
python train_box.py --delta {0.1, 0.25} --tied {--uniform_pb} --loss {TB, DB}
"""

from argparse import ArgumentParser

import numpy as np
@@ -233,9 +232,7 @@ def main(args): # noqa: C901
print(f"current optimizer LR: {optimizer.param_groups[0]['lr']}")

trajectories = gflownet.sample_trajectories(
-    env,
-    sample_off_policy=False,
-    n_samples=args.batch_size
+    env, sample_off_policy=False, n_samples=args.batch_size
)

training_samples = gflownet.to_training_samples(trajectories)
10 changes: 3 additions & 7 deletions tutorials/examples/train_discreteebm.py
@@ -10,7 +10,6 @@
[Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782)
python train_hypergrid.py --ndim {2, 4} --height 12 --R0 {1e-3, 1e-4} --tied --loss {TB, DB, SubTB}
"""

from argparse import ArgumentParser

import torch
@@ -20,10 +19,9 @@
from gfn.gflownet import FMGFlowNet
from gfn.gym import DiscreteEBM
from gfn.modules import DiscretePolicyEstimator
-from gfn.utils.common import validate
-from gfn.utils.modules import NeuralNet, Tabular
-
from gfn.utils.common import set_seed
+from gfn.utils.modules import NeuralNet, Tabular
+from gfn.utils.training import validate

DEFAULT_SEED = 4444
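Note that in the import hunk above (and its twin in train_hypergrid.py), `validate` is now imported from `gfn.utils.training` rather than `gfn.utils.common`. As a hedged aside, not something this commit adds, a script that must run against both layouts could guard the import:

# Sketch only: tolerate either torchgfn layout for `validate`.
try:
    from gfn.utils.training import validate  # layout used by these tutorials now
except ImportError:  # older layout exposed validate via gfn.utils.common
    from gfn.utils.common import validate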

@@ -72,9 +70,7 @@ def main(args): # noqa: C901
validation_info = {"l1_dist": float("inf")}
for iteration in trange(n_iterations):
trajectories = gflownet.sample_trajectories(
-    env,
-    off_policy=False,
-    n_samples=args.batch_size
+    env, off_policy=False, n_samples=args.batch_size
)
training_samples = gflownet.to_training_samples(trajectories)

10 changes: 5 additions & 5 deletions tutorials/examples/train_hypergrid.py
@@ -10,7 +10,6 @@
[Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782)
python train_hypergrid.py --ndim {2, 4} --height 12 --R0 {1e-3, 1e-4} --tied --loss {TB, DB, SubTB}
"""

from argparse import ArgumentParser

import torch
@@ -28,10 +27,9 @@
)
from gfn.gym import HyperGrid
from gfn.modules import DiscretePolicyEstimator, ScalarEstimator
-from gfn.utils.common import validate
-from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular
-
from gfn.utils.common import set_seed
+from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular
+from gfn.utils.training import validate

DEFAULT_SEED = 4444

@@ -225,7 +223,9 @@ def main(args): # noqa: C901
n_iterations = args.n_trajectories // args.batch_size
validation_info = {"l1_dist": float("inf")}
for iteration in trange(n_iterations):
-trajectories = gflownet.sample_trajectories(env, n_samples=args.batch_size, sample_off_policy=off_policy_sampling)
+trajectories = gflownet.sample_trajectories(
+    env, n_samples=args.batch_size, sample_off_policy=off_policy_sampling
+)
training_samples = gflownet.to_training_samples(trajectories)
if replay_buffer is not None:
with torch.no_grad():
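The hypergrid hunk above goes the opposite way from train_box.py: a call that overflows black's default 88-column limit is split, with all arguments gathered on one indented line. Below is a self-contained sketch of the two layouts black settles on depending on whether a magic trailing comma is present; `sample_trajectories` here is a local stand-in stub, not the gfn API.

# Illustration only: black's layout for a call that is too long for one line.
def sample_trajectories(env, n_samples, sample_off_policy):
    return [env, n_samples, sample_off_policy]


env, batch_size, off_policy_sampling = "toy-env", 16, True

# No trailing comma: black packs the arguments onto a single indented line,
# since together they fit within the 88-column limit.
trajectories = sample_trajectories(
    env, n_samples=batch_size, sample_off_policy=off_policy_sampling
)

# Magic trailing comma after the last argument: black keeps one argument
# per line and leaves the call exploded.
trajectories = sample_trajectories(
    env,
    n_samples=batch_size,
    sample_off_policy=off_policy_sampling,
)
print(trajectories)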
9 changes: 4 additions & 5 deletions tutorials/examples/train_line.py
@@ -1,4 +1,3 @@
-import random
from typing import ClassVar, Literal, Tuple

import matplotlib.pyplot as plt
@@ -15,7 +14,6 @@
from gfn.modules import GFNModule
from gfn.states import States
from gfn.utils import NeuralNet

from gfn.utils.common import set_seed


@@ -215,7 +213,9 @@ def log_prob(self, sampled_actions):

actions_to_eval[~exit_idx] = sampled_actions[~exit_idx]
if sum(~exit_idx) > 0:
-logprobs[~exit_idx] = self.dist.log_prob(actions_to_eval)[~exit_idx].unsqueeze(-1)
+logprobs[~exit_idx] = self.dist.log_prob(actions_to_eval)[
+    ~exit_idx
+].unsqueeze(-1)

return logprobs.squeeze(-1)
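For readers skimming the reformatted line above: it writes log-probabilities only for non-exit actions and leaves the rest untouched. A minimal, self-contained sketch of that masking pattern with made-up tensors (not the Line environment's actual data):

import torch
from torch.distributions import Normal

# Toy stand-ins: a batch of 4 sampled actions, with +inf marking the exit action.
sampled_actions = torch.tensor([0.1, float("inf"), -0.3, 0.7])
exit_idx = torch.isinf(sampled_actions)
dist = Normal(torch.zeros(4), torch.ones(4))

logprobs = torch.zeros(4, 1)
actions_to_eval = torch.zeros(4)
actions_to_eval[~exit_idx] = sampled_actions[~exit_idx]
if (~exit_idx).sum() > 0:
    # Same shape gymnastics as the reformatted line: index the per-action
    # log-probs with the mask, then add a trailing dim to match logprobs.
    logprobs[~exit_idx] = dist.log_prob(actions_to_eval)[~exit_idx].unsqueeze(-1)

print(logprobs.squeeze(-1))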

@@ -289,6 +289,7 @@ def to_probability_distribution(
n_steps=self.n_steps_per_trajectory,
)


def train(
gflownet,
env,
@@ -322,7 +323,6 @@ def train(
scale_schedule = np.linspace(exploration_var_starting_val, 0, n_iterations)

for iteration in tbar:

optimizer.zero_grad()
# Off Policy Sampling.
trajectories = gflownet.sample_trajectories(
@@ -361,7 +361,6 @@


if __name__ == "__main__":

environment = Line(
mus=[2, 5],
sigmas=[0.5, 0.5],
