Skip to content

Commit

Permalink
typing
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Oct 9, 2024
1 parent 6b47e06 commit 988faf0
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 8 deletions.
12 changes: 8 additions & 4 deletions src/gfn/gym/helpers/box_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""This file contains utilitary functions for the Box environment."""

from typing import Tuple
from typing import Tuple, Any

import numpy as np
import torch
Expand Down Expand Up @@ -454,7 +454,7 @@ def __init__(
n_hidden_layers: int,
n_components_s0: int,
n_components: int,
**kwargs,
**kwargs: Any,
):
"""Instantiates the neural network for the forward policy.
Expand Down Expand Up @@ -561,7 +561,11 @@ class BoxPBNeuralNet(NeuralNet):
"""

def __init__(
self, hidden_dim: int, n_hidden_layers: int, n_components: int, **kwargs
self,
hidden_dim: int,
n_hidden_layers: int,
n_components: int,
**kwargs: Any,
):
"""Instantiates the neural network.
Expand Down Expand Up @@ -601,7 +605,7 @@ def forward(
class BoxStateFlowModule(NeuralNet):
"""A deep neural network for the state flow function."""

def __init__(self, logZ_value: torch.Tensor, **kwargs):
def __init__(self, logZ_value: torch.Tensor, **kwargs: Any):
super().__init__(**kwargs)
self.logZ_value = nn.Parameter(logZ_value)

Expand Down
9 changes: 5 additions & 4 deletions tutorials/examples/train_conditional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python
import torch
from tqdm import tqdm
from torch.optim import Adam

from gfn.gflownet import TBGFlowNet, DBGFlowNet, FMGFlowNet, SubTBGFlowNet, ModifiedDBGFlowNet
from gfn.gym import HyperGrid
Expand Down Expand Up @@ -173,17 +174,17 @@ def train(env, gflownet):

# Policy parameters and logZ/logF get independent LRs (logF/Z typically higher).
if type(gflownet) is TBGFlowNet:
optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=lr)
optimizer = Adam(gflownet.pf_pb_parameters(), lr=lr)
optimizer.add_param_group({"params": gflownet.logz_parameters(), "lr": lr * 100})
elif type(gflownet) is DBGFlowNet or type(gflownet) is SubTBGFlowNet:
optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=lr)
optimizer = Adam(gflownet.pf_pb_parameters(), lr=lr)
optimizer.add_param_group({"params": gflownet.logF_parameters(), "lr": lr * 100})
elif type(gflownet) is FMGFlowNet or type(gflownet) is ModifiedDBGFlowNet:
optimizer = torch.optim.Adam(gflownet.parameters(), lr=lr)
optimizer = Adam(gflownet.parameters(), lr=lr)
else:
print("What is this gflownet? {}".format(type(gflownet)))

n_iterations = int(10) #1e4)
n_iterations = int(10) # 1e4)
batch_size = int(1e4)

print("+ Training Conditional {}!".format(type(gflownet)))
Expand Down

0 comments on commit 988faf0

Please sign in to comment.