diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index 2b90d00e..4f624573 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -80,7 +80,7 @@ def to_training_samples(self, trajectories: Trajectories) -> TrainingSampleType: """Converts trajectories to training samples. The type depends on the GFlowNet.""" @abstractmethod - def loss(self, env: Env, training_objects): + def loss(self, env: Env, training_objects) -> Tensor: """Computes the loss given the training objects.""" diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py index c0e9f2d8..c0819cef 100644 --- a/tutorials/examples/train_box.py +++ b/tutorials/examples/train_box.py @@ -8,6 +8,7 @@ """ from argparse import ArgumentParser +from typing import Tuple import numpy as np import torch @@ -24,6 +25,7 @@ SubTBGFlowNet, TBGFlowNet, ) +from gfn.gflownet.base import GFlowNet from gfn.gym import Box from gfn.gym.helpers.box_utils import ( BoxPBEstimator, @@ -87,25 +89,9 @@ def estimate_jsd(kde1, kde2): return jsd / 2.0 -def main(args): # noqa: C901 - seed = args.seed if args.seed != 0 else DEFAULT_SEED - set_seed(seed) - - device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" - - use_wandb = len(args.wandb_project) > 0 - if use_wandb: - wandb.init(project=args.wandb_project) - wandb.config.update(args) - - n_iterations = args.n_trajectories // args.batch_size - - # 1. Create the environment - env = Box(delta=args.delta, epsilon=1e-10, device_str=device_str) - - # 2. Create the gflownet. - # For this we need modules and estimators. - # Depending on the loss, we may need several estimators: +def make_gflownet( + args, env +) -> Tuple[GFlowNet, BoxPFNeuralNet, BoxPBNeuralNet | None, BoxStateFlowModule | None]: gflownet = None pf_module = BoxPFNeuralNet( hidden_dim=args.hidden_dim, @@ -181,11 +167,35 @@ def main(args): # noqa: C901 ) assert gflownet is not None, f"No gflownet for loss {args.loss}" + return gflownet, pf_module, pb_module, module + + +def main(args): # noqa: C901 + seed = args.seed if args.seed != 0 else DEFAULT_SEED + set_seed(seed) + + device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" + + use_wandb = len(args.wandb_project) > 0 + if use_wandb: + wandb.init(project=args.wandb_project) + wandb.config.update(args) + + n_iterations = args.n_trajectories // args.batch_size + + # 1. Create the environment + env = Box(delta=args.delta, epsilon=1e-10, device_str=device_str) + + # 2. Create the gflownet. + # For this we need modules and estimators. + # Depending on the loss, we may need several estimators: + gflownet, pf_module, pb_module, module = make_gflownet(args, env) # 3. Create the optimizer and scheduler optimizer = torch.optim.Adam(pf_module.parameters(), lr=args.lr) if not args.uniform_pb: + assert pb_module is not None optimizer.add_param_group( { "params": ( @@ -235,14 +245,9 @@ def main(args): # noqa: C901 trajectories = gflownet.sample_trajectories( env, save_logprobs=True, n_samples=args.batch_size ) - training_samples = gflownet.to_training_samples(trajectories) optimizer.zero_grad() - if isinstance(gflownet, DBGFlowNet): - assert isinstance(training_samples, Transitions) - loss = gflownet.loss(env, training_samples) - else: - assert isinstance(training_samples, Trajectories) - loss = gflownet.loss(env, training_samples) + training_samples = gflownet.to_training_samples(trajectories) + loss = gflownet.loss(env, training_samples) loss.backward() for p in gflownet.parameters(): diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 05240666..97025772 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -41,27 +41,7 @@ DEFAULT_SEED = 4444 -def main(args): # noqa: C901 - seed = args.seed if args.seed != 0 else DEFAULT_SEED - set_seed(seed) - device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" - - use_wandb = len(args.wandb_project) > 0 - if use_wandb: - wandb.init(project=args.wandb_project) - wandb.config.update(args) - - # 1. Create the environment - env = HyperGrid( - args.ndim, args.height, args.R0, args.R1, args.R2, device_str=device_str - ) - - # 2. Create the gflownets. - # For this we need modules and estimators. - # Depending on the loss, we may need several estimators: - # one (forward only) for FM loss, - # two (forward and backward) or other losses - # three (same, + logZ) estimators for TB. +def _make_gflownet(args, env) -> GFlowNet: gflownet: Optional[GFlowNet] = None if args.loss == "FM": # We need a LogEdgeFlowEstimator @@ -179,6 +159,26 @@ def main(args): # noqa: C901 ) assert gflownet is not None, f"No gflownet for loss {args.loss}" + return gflownet + + +def main(args): + seed = args.seed if args.seed != 0 else DEFAULT_SEED + set_seed(seed) + device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" + + use_wandb = len(args.wandb_project) > 0 + if use_wandb: + wandb.init(project=args.wandb_project) + wandb.config.update(args) + + # 1. Create the environment + env = HyperGrid( + args.ndim, args.height, args.R0, args.R1, args.R2, device_str=device_str + ) + + # 2. Create the gflownets. + gflownet = _make_gflownet(args, env) # Initialize the replay buffer ? replay_buffer = None