diff --git a/src/gfn/states.py b/src/gfn/states.py index 416e5670..53492861 100644 --- a/src/gfn/states.py +++ b/src/gfn/states.py @@ -410,9 +410,9 @@ def set_nonexit_action_masks(self, cond, allow_exit: bool): trajectory - if so, it should be set to True. """ if allow_exit: - exit_idx = torch.zeros(self.batch_shape + (1,)) + exit_idx = torch.zeros(self.batch_shape + (1,)).to(cond.device) else: - exit_idx = torch.ones(self.batch_shape + (1,)) + exit_idx = torch.ones(self.batch_shape + (1,)).to(cond.device) self.forward_masks[torch.cat([cond, exit_idx], dim=-1).bool()] = False def set_exit_masks(self, batch_idx): diff --git a/src/gfn/utils/modules.py b/src/gfn/utils/modules.py index 2ffbf54a..9820fa05 100644 --- a/src/gfn/utils/modules.py +++ b/src/gfn/utils/modules.py @@ -54,7 +54,6 @@ def __init__( else: self.torso = torso self.last_layer = nn.Linear(self.torso.hidden_dim, output_dim) - self.device = None def forward( self, preprocessed_states: TT["batch_shape", "input_dim", float] @@ -66,11 +65,6 @@ def forward( ingestion by the MLP. Returns: out, a set of continuous variables. """ - if self.device is None: - self.device = preprocessed_states.device - self.to( - self.device - ) # TODO: This is maybe fine but could result in weird errors if the model keeps bouncing between devices. out = self.torso(preprocessed_states) out = self.last_layer(out) return out diff --git a/tutorials/examples/train_hypergrid_simple.py b/tutorials/examples/train_hypergrid_simple.py new file mode 100644 index 00000000..d21ef349 --- /dev/null +++ b/tutorials/examples/train_hypergrid_simple.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python +import torch +from tqdm import tqdm + +from gfn.gflownet import TBGFlowNet +from gfn.gym import HyperGrid +from gfn.modules import DiscretePolicyEstimator +from gfn.samplers import Sampler +from gfn.utils import NeuralNet + +torch.manual_seed(0) +exploration_rate = 0.5 +learning_rate = 0.0005 + +# Setup the Environment. +env = HyperGrid( + ndim=5, + height=2, + device_str="cuda" if torch.cuda.is_available() else "cpu", +) + +# Build the GFlowNet. +module_PF = NeuralNet( + input_dim=env.preprocessor.output_dim, + output_dim=env.n_actions, +) +module_PB = NeuralNet( + input_dim=env.preprocessor.output_dim, + output_dim=env.n_actions - 1, + torso=module_PF.torso, +) +pf_estimator = DiscretePolicyEstimator( + module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor +) +pb_estimator = DiscretePolicyEstimator( + module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor +) +gflownet = TBGFlowNet(init_logZ=0.0, pf=pf_estimator, pb=pb_estimator, off_policy=True) + +# Feed pf to the sampler. +sampler = Sampler(estimator=pf_estimator) + +# Move the gflownet to the GPU. +if torch.cuda.is_available(): + gflownet = gflownet.to("cuda") + +# Policy parameters have their own LR. Log Z gets dedicated learning rate +# (typically higher). +optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=1e-3) +optimizer.add_param_group({"params": gflownet.logz_parameters(), "lr": 1e-1}) + +n_iterations = int(1e4) +batch_size = int(1e5) + +for i in (pbar := tqdm(range(n_iterations))): + trajectories = sampler.sample_trajectories( + env, + n_trajectories=batch_size, + off_policy=True, + epsilon=exploration_rate, + ) + optimizer.zero_grad() + loss = gflownet.loss(env, trajectories) + loss.backward() + optimizer.step() + pbar.set_postfix({"loss": loss.item()})