Commit

fix pyright
younik committed Oct 12, 2024
1 parent 68b5c04 commit 28e25b8
Showing 3 changed files with 53 additions and 48 deletions.
2 changes: 1 addition & 1 deletion src/gfn/gflownet/base.py
@@ -80,7 +80,7 @@ def to_training_samples(self, trajectories: Trajectories) -> TrainingSampleType:
         """Converts trajectories to training samples. The type depends on the GFlowNet."""
 
     @abstractmethod
-    def loss(self, env: Env, training_objects):
+    def loss(self, env: Env, training_objects) -> Tensor:
         """Computes the loss given the training objects."""
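Note (not part of the commit): the only change in this file is the `-> Tensor` return annotation on the abstract `loss`. A minimal sketch of the effect at a call site, reusing names from the train_box.py hunks below; with the annotation, pyright infers a Tensor and accepts Tensor-only calls such as `.backward()` without any narrowing:

    # Sketch, assuming `gflownet`, `env`, and `training_samples` as in train_box.py.
    loss = gflownet.loss(env, training_samples)  # pyright now infers Tensor
    loss.backward()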
57 changes: 31 additions & 26 deletions tutorials/examples/train_box.py
@@ -8,6 +8,7 @@
 """
 
 from argparse import ArgumentParser
+from typing import Tuple
 
 import numpy as np
 import torch
@@ -24,6 +25,7 @@
     SubTBGFlowNet,
     TBGFlowNet,
 )
+from gfn.gflownet.base import GFlowNet
 from gfn.gym import Box
 from gfn.gym.helpers.box_utils import (
     BoxPBEstimator,
@@ -87,25 +89,9 @@ def estimate_jsd(kde1, kde2):
     return jsd / 2.0
 
 
-def main(args):  # noqa: C901
-    seed = args.seed if args.seed != 0 else DEFAULT_SEED
-    set_seed(seed)
-
-    device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
-
-    use_wandb = len(args.wandb_project) > 0
-    if use_wandb:
-        wandb.init(project=args.wandb_project)
-        wandb.config.update(args)
-
-    n_iterations = args.n_trajectories // args.batch_size
-
-    # 1. Create the environment
-    env = Box(delta=args.delta, epsilon=1e-10, device_str=device_str)
-
-    # 2. Create the gflownet.
-    # For this we need modules and estimators.
-    # Depending on the loss, we may need several estimators:
+def make_gflownet(
+    args, env
+) -> Tuple[GFlowNet, BoxPFNeuralNet, BoxPBNeuralNet | None, BoxStateFlowModule | None]:
     gflownet = None
     pf_module = BoxPFNeuralNet(
         hidden_dim=args.hidden_dim,
@@ -181,11 +167,35 @@ def main(args):  # noqa: C901
         )
 
     assert gflownet is not None, f"No gflownet for loss {args.loss}"
+    return gflownet, pf_module, pb_module, module
+
+
+def main(args):  # noqa: C901
+    seed = args.seed if args.seed != 0 else DEFAULT_SEED
+    set_seed(seed)
+
+    device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
+
+    use_wandb = len(args.wandb_project) > 0
+    if use_wandb:
+        wandb.init(project=args.wandb_project)
+        wandb.config.update(args)
+
+    n_iterations = args.n_trajectories // args.batch_size
+
+    # 1. Create the environment
+    env = Box(delta=args.delta, epsilon=1e-10, device_str=device_str)
+
+    # 2. Create the gflownet.
+    # For this we need modules and estimators.
+    # Depending on the loss, we may need several estimators:
+    gflownet, pf_module, pb_module, module = make_gflownet(args, env)
 
     # 3. Create the optimizer and scheduler
 
     optimizer = torch.optim.Adam(pf_module.parameters(), lr=args.lr)
     if not args.uniform_pb:
+        assert pb_module is not None
         optimizer.add_param_group(
             {
                 "params": (
@@ -235,14 +245,9 @@ def main(args):  # noqa: C901
         trajectories = gflownet.sample_trajectories(
             env, save_logprobs=True, n_samples=args.batch_size
         )
-        training_samples = gflownet.to_training_samples(trajectories)
         optimizer.zero_grad()
-        if isinstance(gflownet, DBGFlowNet):
-            assert isinstance(training_samples, Transitions)
-            loss = gflownet.loss(env, training_samples)
-        else:
-            assert isinstance(training_samples, Trajectories)
-            loss = gflownet.loss(env, training_samples)
+        training_samples = gflownet.to_training_samples(trajectories)
+        loss = gflownet.loss(env, training_samples)
 
         loss.backward()
         for p in gflownet.parameters():
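Why this type-checks (a sketch, not the project's code; `Widget` below is a hypothetical stand-in for a GFlowNet): `make_gflownet` declares a non-Optional return type and only reaches `return` after the `assert gflownet is not None`, so pyright narrows the Optional local inside the helper and `main` receives plainly typed objects. The same pattern in isolation:

    from typing import Optional

    class Widget:  # hypothetical stand-in
        pass

    def make_widget(kind: str) -> Widget:
        widget: Optional[Widget] = None
        if kind == "a":
            widget = Widget()
        # the assert narrows Optional[Widget] to Widget, satisfying the return type
        assert widget is not None, f"No widget for kind {kind}"
        return widget

    w = make_widget("a")  # pyright sees Widget, not Optional[Widget]

Together with the `-> Tensor` annotation on the base-class `loss`, this is what lets the training loop above drop its per-subclass `isinstance` branches.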
42 changes: 21 additions & 21 deletions tutorials/examples/train_hypergrid.py
@@ -41,27 +41,7 @@
 DEFAULT_SEED = 4444
 
 
-def main(args):  # noqa: C901
-    seed = args.seed if args.seed != 0 else DEFAULT_SEED
-    set_seed(seed)
-    device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
-
-    use_wandb = len(args.wandb_project) > 0
-    if use_wandb:
-        wandb.init(project=args.wandb_project)
-        wandb.config.update(args)
-
-    # 1. Create the environment
-    env = HyperGrid(
-        args.ndim, args.height, args.R0, args.R1, args.R2, device_str=device_str
-    )
-
-    # 2. Create the gflownets.
-    # For this we need modules and estimators.
-    # Depending on the loss, we may need several estimators:
-    # one (forward only) for FM loss,
-    # two (forward and backward) or other losses
-    # three (same, + logZ) estimators for TB.
+def _make_gflownet(args, env) -> GFlowNet:
     gflownet: Optional[GFlowNet] = None
     if args.loss == "FM":
         # We need a LogEdgeFlowEstimator
@@ -179,6 +159,26 @@ def main(args):  # noqa: C901
         )
 
     assert gflownet is not None, f"No gflownet for loss {args.loss}"
+    return gflownet
+
+
+def main(args):
+    seed = args.seed if args.seed != 0 else DEFAULT_SEED
+    set_seed(seed)
+    device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
+
+    use_wandb = len(args.wandb_project) > 0
+    if use_wandb:
+        wandb.init(project=args.wandb_project)
+        wandb.config.update(args)
+
+    # 1. Create the environment
+    env = HyperGrid(
+        args.ndim, args.height, args.R0, args.R1, args.R2, device_str=device_str
+    )
+
+    # 2. Create the gflownets.
+    gflownet = _make_gflownet(args, env)
 
     # Initialize the replay buffer ?
     replay_buffer = None
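The same refactor, without the extra module handles: `_make_gflownet` returns a plain `GFlowNet`, so the Optional-typed local never escapes the helper, and the simplified `main` also drops its `# noqa: C901` waiver. A sketch of the resulting call shape, with the sampling call copied from the train_box.py hunk above rather than from this file:

    # sketch only
    gflownet = _make_gflownet(args, env)  # typed as GFlowNet, never Optional
    trajectories = gflownet.sample_trajectories(
        env, save_logprobs=True, n_samples=args.batch_size
    )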
