From 4434e5f9f28491199a73f02aca2893747146df8d Mon Sep 17 00:00:00 2001
From: Joseph Viviano <joseph@viviano.ca>
Date: Tue, 1 Oct 2024 12:33:17 -0400
Subject: [PATCH] Update tutorial examples to the new sampling and logZ API

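Update the tutorial examples to match the current sampling API:

- gflownet.sample_trajectories() takes n= rather than n_samples=
  (or n_trajectories= on the standalone Sampler).
- TBGFlowNet takes logZ= rather than init_logZ=.
- The simple hypergrid examples sample directly from the gflownet
  instead of wrapping the pf estimator in a separate Sampler.
- train_ising.py passes explicit save_logprobs/save_estimator_outputs
  flags in place of the old off_policy= argument.

A minimal before/after sketch of the call-site migration, using the
names from these examples:

    # Before:
    trajectories = gflownet.sample_trajectories(
        env, save_logprobs=True, n_samples=args.batch_size
    )

    # After:
    trajectories = gflownet.sample_trajectories(
        env, save_logprobs=True, n=args.batch_size
    )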
---
 tutorials/examples/train_box.py                      |  5 +++--
 tutorials/examples/train_discreteebm.py              |  3 +--
 tutorials/examples/train_hypergrid.py                |  2 +-
 tutorials/examples/train_hypergrid_simple.py         | 12 ++++--------
 .../examples/train_hypergrid_simple_conditional.py   |  4 ++--
 tutorials/examples/train_ising.py                    |  8 +++++++-
 tutorials/examples/train_line.py                     |  4 ++--
 7 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py
index 8bf7ec5b..b6eeedc6 100644
--- a/tutorials/examples/train_box.py
+++ b/tutorials/examples/train_box.py
@@ -230,8 +230,9 @@ def main(args):  # noqa: C901
         if iteration % 1000 == 0:
             print(f"current optimizer LR: {optimizer.param_groups[0]['lr']}")
 
+        # Sampling on-policy, so we save logprobs for faster computation.
         trajectories = gflownet.sample_trajectories(
-            env, save_logprobs=True, n_samples=args.batch_size
+            env, save_logprobs=True, n=args.batch_size
         )
 
         training_samples = gflownet.to_training_samples(trajectories)
@@ -241,7 +242,7 @@ def main(args):  # noqa: C901
 
         loss.backward()
         for p in gflownet.parameters():
-            if p.ndim > 0 and p.grad is not None:  # We do not clip logZ grad
+            if p.ndim > 0 and p.grad is not None:  # We do not clip logZ grad.
                 p.grad.data.clamp_(-10, 10).nan_to_num_(0.0)
         optimizer.step()
         scheduler.step()
diff --git a/tutorials/examples/train_discreteebm.py b/tutorials/examples/train_discreteebm.py
index 45537686..9bac6c26 100644
--- a/tutorials/examples/train_discreteebm.py
+++ b/tutorials/examples/train_discreteebm.py
@@ -63,7 +63,6 @@ def main(args):  # noqa: C901
     optimizer = torch.optim.Adam(module.parameters(), lr=args.lr)
 
     # 4. Train the gflownet
-
     visited_terminating_states = env.states_from_batch_shape((0,))
 
     states_visited = 0
@@ -71,7 +70,7 @@ def main(args):  # noqa: C901
     validation_info = {"l1_dist": float("inf")}
     for iteration in trange(n_iterations):
         trajectories = gflownet.sample_trajectories(
-            env, save_logprobs=True, n_samples=args.batch_size
+            env, save_logprobs=True, n=args.batch_size
         )
         training_samples = gflownet.to_training_samples(trajectories)
 
diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py
index eec3366b..c89f3274 100644
--- a/tutorials/examples/train_hypergrid.py
+++ b/tutorials/examples/train_hypergrid.py
@@ -229,7 +229,7 @@ def main(args):  # noqa: C901
     for iteration in trange(n_iterations):
         trajectories = gflownet.sample_trajectories(
             env,
-            n_samples=args.batch_size,
+            n=args.batch_size,
             save_logprobs=args.replay_buffer_size == 0,
             save_estimator_outputs=False,
         )
diff --git a/tutorials/examples/train_hypergrid_simple.py b/tutorials/examples/train_hypergrid_simple.py
index 826eebca..d2d5bccc 100644
--- a/tutorials/examples/train_hypergrid_simple.py
+++ b/tutorials/examples/train_hypergrid_simple.py
@@ -5,7 +5,6 @@
 from gfn.gflownet import TBGFlowNet
 from gfn.gym import HyperGrid
 from gfn.modules import DiscretePolicyEstimator
-from gfn.samplers import Sampler
 from gfn.utils import NeuralNet
 
 torch.manual_seed(0)
@@ -35,10 +34,7 @@
 pb_estimator = DiscretePolicyEstimator(
     module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor
 )
-gflownet = TBGFlowNet(init_logZ=0.0, pf=pf_estimator, pb=pb_estimator)
-
-# Feed pf to the sampler.
-sampler = Sampler(estimator=pf_estimator)
+gflownet = TBGFlowNet(logZ=0.0, pf=pf_estimator, pb=pb_estimator)
 
 # Move the gflownet to the GPU.
 if torch.cuda.is_available():
@@ -53,9 +49,9 @@
 batch_size = int(1e5)
 
 for i in (pbar := tqdm(range(n_iterations))):
-    trajectories = sampler.sample_trajectories(
+    trajectories = gflownet.sample_trajectories(
         env,
-        n_trajectories=batch_size,
+        n=batch_size,
         save_logprobs=False,
         save_estimator_outputs=True,
         epsilon=exploration_rate,
@@ -64,4 +60,4 @@
     loss = gflownet.loss(env, trajectories)
     loss.backward()
     optimizer.step()
-    pbar.set_postfix({"loss": loss.item()})
\ No newline at end of file
+    pbar.set_postfix({"loss": loss.item()})
diff --git a/tutorials/examples/train_hypergrid_simple_conditional.py b/tutorials/examples/train_hypergrid_simple_conditional.py
index 781c364c..d17d6227 100644
--- a/tutorials/examples/train_hypergrid_simple_conditional.py
+++ b/tutorials/examples/train_hypergrid_simple_conditional.py
@@ -97,9 +97,9 @@
     conditioning = torch.rand((batch_size, 1))
     conditioning = (conditioning > 0.5).to(torch.float)  # Randomly 0 or 1.
 
-    trajectories = sampler.sample_trajectories(
+    trajectories = gflownet.sample_trajectories(
         env,
-        n_trajectories=batch_size,
+        n=batch_size,
         conditioning=conditioning,
         save_logprobs=False,
         save_estimator_outputs=True,
diff --git a/tutorials/examples/train_ising.py b/tutorials/examples/train_ising.py
index 1ca2c656..878c11cf 100644
--- a/tutorials/examples/train_ising.py
+++ b/tutorials/examples/train_ising.py
@@ -83,8 +83,14 @@ def ising_n_to_ij(L, n):
     # Learning
     visited_terminating_states = env.States.from_batch_shape((0,))
     states_visited = 0
+
     for i in (pbar := tqdm(range(10000))):
-        trajectories = gflownet.sample_trajectories(env, n_samples=8, off_policy=False)
+        trajectories = gflownet.sample_trajectories(
+            env,
+            n=8,
+            save_estimator_outputs=False,
+            save_logprobs=True,
+        )
         training_samples = gflownet.to_training_samples(trajectories)
         optimizer.zero_grad()
         loss = gflownet.loss(env, training_samples)
diff --git a/tutorials/examples/train_line.py b/tutorials/examples/train_line.py
index 6ce7fde6..c43115f9 100644
--- a/tutorials/examples/train_line.py
+++ b/tutorials/examples/train_line.py
@@ -227,7 +227,7 @@ def train(
         # Off Policy Sampling.
         trajectories = gflownet.sample_trajectories(
             env,
-            n_samples=batch_size,
+            n=batch_size,
             save_estimator_outputs=True,
             save_logprobs=False,
             scale_factor=scale_schedule[iteration],  # Off policy kwargs.
@@ -292,7 +292,7 @@ def train(
         policy_std_max=policy_std_max,
     )
     pb = StepEstimator(environment, pb_module, backward=True)
-    gflownet = TBGFlowNet(pf=pf, pb=pb, init_logZ=0.0)
+    gflownet = TBGFlowNet(pf=pf, pb=pb, logZ=0.0)
 
     gflownet = train(
         gflownet,