diff --git a/src/gfn/gym/helpers/box_utils.py b/src/gfn/gym/helpers/box_utils.py
index fa6e7111..14566be5 100644
--- a/src/gfn/gym/helpers/box_utils.py
+++ b/src/gfn/gym/helpers/box_utils.py
@@ -1,6 +1,6 @@
 """This file contains utilitary functions for the Box environment."""
 
-from typing import Tuple
+from typing import Tuple, Any
 
 import numpy as np
 import torch
@@ -454,7 +454,7 @@ def __init__(
         n_hidden_layers: int,
         n_components_s0: int,
         n_components: int,
-        **kwargs,
+        **kwargs: Any,
     ):
         """Instantiates the neural network for the forward policy.
 
@@ -561,7 +561,11 @@ class BoxPBNeuralNet(NeuralNet):
     """
 
     def __init__(
-        self, hidden_dim: int, n_hidden_layers: int, n_components: int, **kwargs
+        self,
+        hidden_dim: int,
+        n_hidden_layers: int,
+        n_components: int,
+        **kwargs: Any,
     ):
         """Instantiates the neural network.
 
@@ -601,7 +605,7 @@ def forward(
 class BoxStateFlowModule(NeuralNet):
     """A deep neural network for the state flow function."""
 
-    def __init__(self, logZ_value: torch.Tensor, **kwargs):
+    def __init__(self, logZ_value: torch.Tensor, **kwargs: Any):
         super().__init__(**kwargs)
         self.logZ_value = nn.Parameter(logZ_value)
 
diff --git a/tutorials/examples/train_conditional.py b/tutorials/examples/train_conditional.py
index abbff55d..0f5e0adb 100644
--- a/tutorials/examples/train_conditional.py
+++ b/tutorials/examples/train_conditional.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 import torch
 from tqdm import tqdm
+from torch.optim import Adam
 
 from gfn.gflownet import TBGFlowNet, DBGFlowNet, FMGFlowNet, SubTBGFlowNet, ModifiedDBGFlowNet
 from gfn.gym import HyperGrid
@@ -173,17 +174,17 @@ def train(env, gflownet):
 
     # Policy parameters and logZ/logF get independent LRs (logF/Z typically higher).
     if type(gflownet) is TBGFlowNet:
-        optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=lr)
+        optimizer = Adam(gflownet.pf_pb_parameters(), lr=lr)
         optimizer.add_param_group({"params": gflownet.logz_parameters(), "lr": lr * 100})
     elif type(gflownet) is DBGFlowNet or type(gflownet) is SubTBGFlowNet:
-        optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=lr)
+        optimizer = Adam(gflownet.pf_pb_parameters(), lr=lr)
         optimizer.add_param_group({"params": gflownet.logF_parameters(), "lr": lr * 100})
     elif type(gflownet) is FMGFlowNet or type(gflownet) is ModifiedDBGFlowNet:
-        optimizer = torch.optim.Adam(gflownet.parameters(), lr=lr)
+        optimizer = Adam(gflownet.parameters(), lr=lr)
     else:
         print("What is this gflownet? {}".format(type(gflownet)))
 
-    n_iterations = int(10) #1e4)
+    n_iterations = int(10)  # 1e4)
     batch_size = int(1e4)
 
     print("+ Training Conditional {}!".format(type(gflownet)))
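
Note on the `**kwargs: Any` annotations in `box_utils.py`: strict type checkers such as `mypy --strict` flag an unannotated `**kwargs`, and `Any` is the usual annotation when the extras are forwarded verbatim to a parent constructor (each captured keyword is then typed `Any`). A minimal sketch of the pattern, using hypothetical class names rather than the repo's real `NeuralNet` hierarchy:

```python
from typing import Any


class Base:
    """Stand-in for a configurable parent class such as gfn's NeuralNet."""

    def __init__(self, hidden_dim: int = 256, n_hidden_layers: int = 2) -> None:
        self.hidden_dim = hidden_dim
        self.n_hidden_layers = n_hidden_layers


class Child(Base):
    def __init__(self, n_components: int, **kwargs: Any) -> None:
        # `**kwargs: Any` types each captured keyword as Any; the extras
        # are passed through untouched to the parent constructor.
        super().__init__(**kwargs)
        self.n_components = n_components


child = Child(n_components=4, hidden_dim=128)  # hidden_dim reaches Base
```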
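
The comment in `train_conditional.py` ("Policy parameters and logZ/logF get independent LRs") relies on PyTorch parameter groups: a single optimizer can hold several groups, each with its own hyperparameters. A minimal sketch with a hypothetical two-part model (the `policy`/`logZ` names are illustrative, not from the repo):

```python
import torch
from torch import nn
from torch.optim import Adam

# Hypothetical model: policy weights plus a scalar logZ estimate.
policy = nn.Linear(4, 4)
logZ = nn.Parameter(torch.zeros(1))

lr = 1e-3
optimizer = Adam(policy.parameters(), lr=lr)
# add_param_group registers extra parameters under their own settings,
# here a 100x learning rate, mirroring the logZ/logF treatment above.
optimizer.add_param_group({"params": [logZ], "lr": lr * 100})

print([group["lr"] for group in optimizer.param_groups])  # [0.001, 0.1]
```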