From 4434e5f9f28491199a73f02aca2893747146df8d Mon Sep 17 00:00:00 2001
From: Joseph Viviano <joseph@viviano.ca>
Date: Tue, 1 Oct 2024 12:33:17 -0400
Subject: [PATCH] API changes

---
 tutorials/examples/train_box.py                     |  5 +++--
 tutorials/examples/train_discreteebm.py             |  3 +--
 tutorials/examples/train_hypergrid.py               |  2 +-
 tutorials/examples/train_hypergrid_simple.py        | 12 ++++--------
 .../examples/train_hypergrid_simple_conditional.py  |  4 ++--
 tutorials/examples/train_ising.py                   |  8 +++++++-
 tutorials/examples/train_line.py                    |  4 ++--
 7 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py
index 8bf7ec5b..b6eeedc6 100644
--- a/tutorials/examples/train_box.py
+++ b/tutorials/examples/train_box.py
@@ -230,8 +230,9 @@ def main(args):  # noqa: C901
         if iteration % 1000 == 0:
             print(f"current optimizer LR: {optimizer.param_groups[0]['lr']}")
 
+        # Sampling on-policy, so we save logprobs for faster computation.
         trajectories = gflownet.sample_trajectories(
-            env, save_logprobs=True, n_samples=args.batch_size
+            env, save_logprobs=True, n=args.batch_size
         )
         training_samples = gflownet.to_training_samples(trajectories)
 
@@ -241,7 +242,7 @@ def main(args):  # noqa: C901
         loss.backward()
         for p in gflownet.parameters():
-            if p.ndim > 0 and p.grad is not None:  # We do not clip logZ grad
+            if p.ndim > 0 and p.grad is not None:  # We do not clip logZ grad.
                 p.grad.data.clamp_(-10, 10).nan_to_num_(0.0)
         optimizer.step()
         scheduler.step()
 
diff --git a/tutorials/examples/train_discreteebm.py b/tutorials/examples/train_discreteebm.py
index 45537686..9bac6c26 100644
--- a/tutorials/examples/train_discreteebm.py
+++ b/tutorials/examples/train_discreteebm.py
@@ -63,7 +63,6 @@ def main(args):  # noqa: C901
     optimizer = torch.optim.Adam(module.parameters(), lr=args.lr)
 
     # 4. Train the gflownet
-
     visited_terminating_states = env.states_from_batch_shape((0,))
 
     states_visited = 0
@@ -71,7 +70,7 @@
     validation_info = {"l1_dist": float("inf")}
     for iteration in trange(n_iterations):
         trajectories = gflownet.sample_trajectories(
-            env, save_logprobs=True, n_samples=args.batch_size
+            env, save_logprobs=True, n=args.batch_size
         )
         training_samples = gflownet.to_training_samples(trajectories)
 
diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py
index eec3366b..c89f3274 100644
--- a/tutorials/examples/train_hypergrid.py
+++ b/tutorials/examples/train_hypergrid.py
@@ -229,7 +229,7 @@ def main(args):  # noqa: C901
     for iteration in trange(n_iterations):
         trajectories = gflownet.sample_trajectories(
             env,
-            n_samples=args.batch_size,
+            n=args.batch_size,
             save_logprobs=args.replay_buffer_size == 0,
             save_estimator_outputs=False,
         )
diff --git a/tutorials/examples/train_hypergrid_simple.py b/tutorials/examples/train_hypergrid_simple.py
index 826eebca..d2d5bccc 100644
--- a/tutorials/examples/train_hypergrid_simple.py
+++ b/tutorials/examples/train_hypergrid_simple.py
@@ -5,7 +5,6 @@
 from gfn.gflownet import TBGFlowNet
 from gfn.gym import HyperGrid
 from gfn.modules import DiscretePolicyEstimator
-from gfn.samplers import Sampler
 from gfn.utils import NeuralNet
 
 torch.manual_seed(0)
@@ -35,10 +34,7 @@
 pb_estimator = DiscretePolicyEstimator(
     module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor
 )
-gflownet = TBGFlowNet(init_logZ=0.0, pf=pf_estimator, pb=pb_estimator)
-
-# Feed pf to the sampler.
-sampler = Sampler(estimator=pf_estimator)
+gflownet = TBGFlowNet(logZ=0.0, pf=pf_estimator, pb=pb_estimator)
 
 # Move the gflownet to the GPU.
 if torch.cuda.is_available():
@@ -53,9 +49,9 @@
 batch_size = int(1e5)
 
 for i in (pbar := tqdm(range(n_iterations))):
-    trajectories = sampler.sample_trajectories(
+    trajectories = gflownet.sample_trajectories(
         env,
-        n_trajectories=batch_size,
+        n=batch_size,
         save_logprobs=False,
         save_estimator_outputs=True,
         epsilon=exploration_rate,
@@ -64,4 +60,4 @@
     loss = gflownet.loss(env, trajectories)
     loss.backward()
     optimizer.step()
-    pbar.set_postfix({"loss": loss.item()})
\ No newline at end of file
+    pbar.set_postfix({"loss": loss.item()})
diff --git a/tutorials/examples/train_hypergrid_simple_conditional.py b/tutorials/examples/train_hypergrid_simple_conditional.py
index 781c364c..d17d6227 100644
--- a/tutorials/examples/train_hypergrid_simple_conditional.py
+++ b/tutorials/examples/train_hypergrid_simple_conditional.py
@@ -97,9 +97,9 @@
     conditioning = torch.rand((batch_size, 1))
     conditioning = (conditioning > 0.5).to(torch.float)  # Randomly 1 and zero.
 
-    trajectories = sampler.sample_trajectories(
+    trajectories = gflownet.sample_trajectories(
         env,
-        n_trajectories=batch_size,
+        n=batch_size,
         conditioning=conditioning,
         save_logprobs=False,
         save_estimator_outputs=True,
diff --git a/tutorials/examples/train_ising.py b/tutorials/examples/train_ising.py
index 1ca2c656..878c11cf 100644
--- a/tutorials/examples/train_ising.py
+++ b/tutorials/examples/train_ising.py
@@ -83,8 +83,14 @@ def ising_n_to_ij(L, n):
     # Learning
     visited_terminating_states = env.States.from_batch_shape((0,))
     states_visited = 0
+
    for i in (pbar := tqdm(range(10000))):
-        trajectories = gflownet.sample_trajectories(env, n_samples=8, off_policy=False)
+        trajectories = gflownet.sample_trajectories(
+            env,
+            n=8,
+            save_estimator_outputs=False,
+            save_logprobs=True,
+        )
         training_samples = gflownet.to_training_samples(trajectories)
         optimizer.zero_grad()
         loss = gflownet.loss(env, training_samples)
diff --git a/tutorials/examples/train_line.py b/tutorials/examples/train_line.py
index 6ce7fde6..c43115f9 100644
--- a/tutorials/examples/train_line.py
+++ b/tutorials/examples/train_line.py
@@ -227,7 +227,7 @@ def train(
         # Off Policy Sampling.
         trajectories = gflownet.sample_trajectories(
             env,
-            n_samples=batch_size,
+            n=batch_size,
             save_estimator_outputs=True,
             save_logprobs=False,
             scale_factor=scale_schedule[iteration],  # Off policy kwargs.
@@ -292,7 +292,7 @@ def train(
         policy_std_max=policy_std_max,
     )
     pb = StepEstimator(environment, pb_module, backward=True)
-    gflownet = TBGFlowNet(pf=pf, pb=pb, init_logZ=0.0)
+    gflownet = TBGFlowNet(pf=pf, pb=pb, logZ=0.0)
 
     gflownet = train(
         gflownet,
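
In summary, this patch makes three API changes: the number-of-trajectories argument to
sample_trajectories is now `n` (previously `n_samples` on the GFlowNet methods and
`n_trajectories` on Sampler.sample_trajectories), the TBGFlowNet log-partition argument
is now `logZ` (previously `init_logZ`), and the simple tutorials sample directly from
the gflownet rather than wrapping pf_estimator in a standalone Sampler. The sketch below
shows the new call pattern end to end. It is a minimal sketch assuming the torchgfn
tutorial setup; the environment size, learning rate, batch size, and iteration count are
illustrative placeholders, not values taken from the patch.

    import torch
    from tqdm import tqdm

    from gfn.gflownet import TBGFlowNet
    from gfn.gym import HyperGrid
    from gfn.modules import DiscretePolicyEstimator
    from gfn.utils import NeuralNet

    # Small hypergrid environment; the sizes here are illustrative.
    env = HyperGrid(ndim=2, height=8)

    # The forward policy outputs one logit per action; the backward policy has
    # one fewer output because there is no exit action going backward.
    module_PF = NeuralNet(
        input_dim=env.preprocessor.output_dim, output_dim=env.n_actions
    )
    module_PB = NeuralNet(
        input_dim=env.preprocessor.output_dim,
        output_dim=env.n_actions - 1,
        torso=module_PF.torso,  # Share the torso between PF and PB.
    )
    pf_estimator = DiscretePolicyEstimator(
        module_PF, env.n_actions, preprocessor=env.preprocessor
    )
    pb_estimator = DiscretePolicyEstimator(
        module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor
    )

    # New keyword: `logZ` (was `init_logZ`). No standalone Sampler is needed.
    gflownet = TBGFlowNet(logZ=0.0, pf=pf_estimator, pb=pb_estimator)
    optimizer = torch.optim.Adam(gflownet.parameters(), lr=1e-3)

    for i in (pbar := tqdm(range(1000))):
        # New keyword: `n` (was `n_samples` / `n_trajectories`). Sampling
        # on-policy, so we save logprobs for faster loss computation.
        trajectories = gflownet.sample_trajectories(env, n=16, save_logprobs=True)
        optimizer.zero_grad()
        loss = gflownet.loss(env, trajectories)
        loss.backward()
        optimizer.step()
        pbar.set_postfix({"loss": loss.item()})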