From 6811a8239398c4762de40ce2036dfc4efb9af8bd Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Tue, 20 Feb 2024 17:00:15 -0500 Subject: [PATCH 1/9] added ising example --- tutorials/examples/train_ising.py | 134 ++++++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 tutorials/examples/train_ising.py diff --git a/tutorials/examples/train_ising.py b/tutorials/examples/train_ising.py new file mode 100644 index 00000000..6401dca2 --- /dev/null +++ b/tutorials/examples/train_ising.py @@ -0,0 +1,134 @@ +from argparse import ArgumentParser + +import torch +import wandb +from tqdm import tqdm + +from gfn.gflownet import FMGFlowNet +from gfn.gym import DiscreteEBM +from gfn.gym.discrete_ebm import IsingModel +from gfn.modules import DiscretePolicyEstimator +from gfn.utils.modules import NeuralNet +from gfn.utils.training import validate + + +def main(args): + # Configs + + use_wandb = len(args.wandb_project) > 0 + if use_wandb: + wandb.init(project=args.wandb_project) + wandb.config.update(args) + + device = "cpu" + torch.set_num_threads(args.n_threads) + hidden_dim = 512 + + n_hidden = 2 + acc_fn = "relu" + lr = 0.001 + lr_Z = 0.01 + validation_samples = 1000 + + def make_J(L, coupling_constant): + """Ising model parameters.""" + + def ising_n_to_ij(L, n): + i = n // L + j = n - i * L + return (i, j) + + N = L**2 + J = torch.zeros((N, N), device=torch.device(device)) + for k in range(N): + for m in range(k): + x1, y1 = ising_n_to_ij(L, k) + x2, y2 = ising_n_to_ij(L, m) + if x1 == x2 and abs(y2 - y1) == 1: + J[k][m] = 1 + J[m][k] = 1 + elif y1 == y2 and abs(x2 - x1) == 1: + J[k][m] = 1 + J[m][k] = 1 + + for k in range(L): + J[k * L][(k + 1) * L - 1] = 1 + J[(k + 1) * L - 1][k * L] = 1 + J[k][k + N - L] = 1 + J[k + N - L][k] = 1 + + return coupling_constant * J + + # Ising model env + N = args.L**2 + J = make_J(args.L, args.J) + ising_energy = IsingModel(J) + env = DiscreteEBM(N, alpha=1, energy=ising_energy, device_str=device) + + # Parametrization and losses + pf_module = NeuralNet( + input_dim=env.preprocessor.output_dim, + output_dim=env.n_actions, + hidden_dim=hidden_dim, + n_hidden_layers=n_hidden, + activation_fn=acc_fn, + ) + + pf_estimator = DiscretePolicyEstimator( + pf_module, env.n_actions, env.preprocessor, is_backward=False + ) + gflownet = FMGFlowNet(pf_estimator) + optimizer = torch.optim.Adam(gflownet.parameters(), lr=1e-3) + + # Learning + visited_terminating_states = env.States.from_batch_shape((0,)) + states_visited = 0 + for i in (pbar := tqdm(range(10000))): + trajectories = gflownet.sample_trajectories(env, n_samples=8, off_policy=False) + training_samples = gflownet.to_training_samples(trajectories) + optimizer.zero_grad() + loss = gflownet.loss(env, training_samples) + loss.backward() + optimizer.step() + + states_visited += len(trajectories) + to_log = {"loss": loss.item(), "states_visited": states_visited} + + if i % 25 == 0: + tqdm.write(f"{i}: {to_log}") + + +if __name__ == "__main__": + # Comand-line arguments + parser = ArgumentParser() + + parser.add_argument( + "--n_threads", + type=int, + default=4, + help="Number of threads used by PyTorch", + ) + + parser.add_argument( + "-L", + type=int, + default=16, + help="Lentgh of the grid", + ) + + parser.add_argument( + "-J", + type=float, + default=0.44, + help="J (Magnetic coupling constant)", + ) + + parser.add_argument( + "--wandb_project", + type=str, + default="", + help="Name of the wandb project. If empty, don't use wandb", + ) + + args = parser.parse_args() + main(args) From 45d9893962901b876304758d2df9f39f24b4eeae Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 15:49:44 -0500 Subject: [PATCH 2/9] black --- src/gfn/gflownet/base.py | 1 + testing/test_environments.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index e38bb10a..9bd216cf 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -24,6 +24,7 @@ class GFlowNet(ABC, nn.Module, Generic[TrainingSampleType]): A formal definition of GFlowNets is given in Sec. 3 of [GFlowNet Foundations](https://arxiv.org/pdf/2111.09266). """ + log_reward_clip_min = float("-inf") # Default off. @abstractmethod diff --git a/testing/test_environments.py b/testing/test_environments.py index b110baac..5dbd4cc6 100644 --- a/testing/test_environments.py +++ b/testing/test_environments.py @@ -209,7 +209,9 @@ def test_box_fwd_step(delta: float): ] for failing_actions_list in failing_actions_lists_at_s0: - actions = env.actions_from_tensor(format_tensor(failing_actions_list, discrete=False)) + actions = env.actions_from_tensor( + format_tensor(failing_actions_list, discrete=False) + ) with pytest.raises(NonValidActionsError): states = env._step(states, actions) From 1e72273edaad50e6aa90aa5cfd14788b8010617d Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 16:10:41 -0500 Subject: [PATCH 3/9] default value reduced for grid size --- tutorials/examples/train_ising.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/examples/train_ising.py b/tutorials/examples/train_ising.py index 6401dca2..26ca2864 100644 --- a/tutorials/examples/train_ising.py +++ b/tutorials/examples/train_ising.py @@ -112,7 +112,7 @@ def ising_n_to_ij(L, n): parser.add_argument( "-L", type=int, - default=16, + default=6, help="Lentgh of the grid", ) From c8cf89c64656b7d4c9a7258a7cf5cc59f31b4706 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 16:10:53 -0500 Subject: [PATCH 4/9] typo --- tutorials/examples/train_ising.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/examples/train_ising.py b/tutorials/examples/train_ising.py index 26ca2864..1ca2c656 100644 --- a/tutorials/examples/train_ising.py +++ b/tutorials/examples/train_ising.py @@ -113,7 +113,7 @@ def ising_n_to_ij(L, n): "-L", type=int, default=6, - help="Lentgh of the grid", + help="Length of the grid", ) parser.add_argument( From 1846da1c9400c7526e2e26707039a97227400d6e Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 16:38:32 -0500 Subject: [PATCH 5/9] black upgrade --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 957a60ce..539e3cb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ torch = ">=1.9.0" torchtyping = ">=0.1.4" # dev dependencies. -black = { version = "*", optional = true } +black = { version = "24.2", optional = true } flake8 = { version = "*", optional = true } gitmopy = { version = "*", optional = true } myst-parser = { version = "*", optional = true } From 552e010bfe2c0088c15c143cb68614bd758b2c14 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 16:42:09 -0500 Subject: [PATCH 6/9] upgrade black --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 539e3cb6..3d05ed8c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,7 +86,7 @@ all = [ "Bug Tracker" = "https://github.com/saleml/gfn/issues" [tool.black] -py36 = true +target_version = ["py310"] include = '\.pyi?$' exclude = '''/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|build)/g''' From 21b845d48aafa1ec202a1c21feb2fd1bf8409824 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 16:42:42 -0500 Subject: [PATCH 7/9] black --- src/gfn/gflownet/base.py | 4 +--- src/gfn/gflownet/detailed_balance.py | 4 +--- src/gfn/gym/helpers/box_utils.py | 1 + src/gfn/gym/hypergrid.py | 1 + 4 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index 9bd216cf..ece89bc3 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -201,9 +201,7 @@ def get_pfs_and_pbs( return log_pf_trajectories, log_pb_trajectories - def get_trajectories_scores( - self, trajectories: Trajectories - ) -> Tuple[ + def get_trajectories_scores(self, trajectories: Trajectories) -> Tuple[ TT["n_trajectories", torch.float], TT["n_trajectories", torch.float], TT["n_trajectories", torch.float], diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index 4cb4e6e2..2c9cc723 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -42,9 +42,7 @@ def __init__( self.forward_looking = forward_looking self.log_reward_clip_min = log_reward_clip_min - def get_scores( - self, env: Env, transitions: Transitions - ) -> Tuple[ + def get_scores(self, env: Env, transitions: Transitions) -> Tuple[ TT["n_transitions", float], TT["n_transitions", float], TT["n_transitions", float], diff --git a/src/gfn/gym/helpers/box_utils.py b/src/gfn/gym/helpers/box_utils.py index c6342c75..bc5b18f2 100644 --- a/src/gfn/gym/helpers/box_utils.py +++ b/src/gfn/gym/helpers/box_utils.py @@ -1,4 +1,5 @@ """This file contains utilitary functions for the Box environment.""" + from typing import Tuple import numpy as np diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py index b8bf27d1..9d6d7d0f 100644 --- a/src/gfn/gym/hypergrid.py +++ b/src/gfn/gym/hypergrid.py @@ -1,6 +1,7 @@ """ Copied and Adapted from https://github.com/Tikquuss/GflowNets_Tutorial """ + from typing import Literal, Tuple import torch From 6aa1659d6238d1a366222af884dfbec2ee4b40cd Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 16:46:17 -0500 Subject: [PATCH 8/9] black formatting update --- tutorials/examples/train_box.py | 9 ++++++--- tutorials/examples/train_discreteebm.py | 1 + tutorials/examples/train_hypergrid.py | 1 + 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py index 5a3cf8dd..632d5b78 100644 --- a/tutorials/examples/train_box.py +++ b/tutorials/examples/train_box.py @@ -6,6 +6,7 @@ python train_box.py --delta {0.1, 0.25} --tied {--uniform_pb} --loss {TB, DB} """ + from argparse import ArgumentParser import numpy as np @@ -189,9 +190,11 @@ def main(args): # noqa: C901 if not args.uniform_pb: optimizer.add_param_group( { - "params": pb_module.last_layer.parameters() - if args.tied - else pb_module.parameters(), + "params": ( + pb_module.last_layer.parameters() + if args.tied + else pb_module.parameters() + ), "lr": args.lr, } ) diff --git a/tutorials/examples/train_discreteebm.py b/tutorials/examples/train_discreteebm.py index 3574fa2d..562bb2b4 100644 --- a/tutorials/examples/train_discreteebm.py +++ b/tutorials/examples/train_discreteebm.py @@ -10,6 +10,7 @@ [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782) python train_hypergrid.py --ndim {2, 4} --height 12 --R0 {1e-3, 1e-4} --tied --loss {TB, DB, SubTB} """ + from argparse import ArgumentParser import torch diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 4d4e3a25..f52932e9 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -10,6 +10,7 @@ [Learning GFlowNets from partial episodes for improved convergence and stability](https://arxiv.org/abs/2209.12782) python train_hypergrid.py --ndim {2, 4} --height 12 --R0 {1e-3, 1e-4} --tied --loss {TB, DB, SubTB} """ + from argparse import ArgumentParser import torch From f1a5c7f016b36d1c2ba2809da7eeb7add63c6b1e Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Sat, 24 Feb 2024 16:53:02 -0500 Subject: [PATCH 9/9] extended excludes --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3d05ed8c..36947af0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,7 @@ all = [ [tool.black] target_version = ["py310"] include = '\.pyi?$' -exclude = '''/(\.git|\.hg|\.mypy_cache|\.tox|\.venv|build)/g''' +extend-exclude = '''/(\.git|\.hg|\.mypy_cache|\.ipynb|\.tox|\.venv|build)/g''' [tool.tox] legacy_tox_ini = '''