From 1c4ec37d97cea4ae3e6d8431083725e167cc82f3 Mon Sep 17 00:00:00 2001
From: Joseph Viviano <joseph@viviano.ca>
Date: Wed, 13 Nov 2024 15:38:29 -0500
Subject: [PATCH 1/7] Scalar estimators allow for the reduction over many
 output values (i.e., the output of the nn.Module does not need to be a
 scalar, because the Estimator will apply a reduction to the final output if
 required).

---
 src/gfn/modules.py | 108 +++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 99 insertions(+), 9 deletions(-)

diff --git a/src/gfn/modules.py b/src/gfn/modules.py
index 4351b462..3dfdc89b 100644
--- a/src/gfn/modules.py
+++ b/src/gfn/modules.py
@@ -134,9 +134,71 @@ def to_probability_distribution(
 
 
 class ScalarEstimator(GFNModule):
+    r"""Class for estimating scalars such as LogZ.
+
+    Training a GFlowNet requires sometimes requires the estimation of precise scalar
+    values, such as the partition function of flows on the DAG. This Estimator is
+    designed for those cases.
+
+    The function approximator used for `module` need not directly output a scalar. If
+    it does not, `reduction` will be used to aggregate the outputs of the module into
+    a single scalar.
+
+    Attributes:
+        preprocessor: Preprocessor object that transforms raw States objects to tensors
+            that can be used as input to the module. Optional, defaults to
+            `IdentityPreprocessor`.
+        module: The module to use. If the module is a Tabular module (from
+            `gfn.utils.modules`), then the environment preprocessor needs to be an
+            `EnumPreprocessor`.
+        preprocessor: Preprocessor from the environment.
+        _output_dim_is_checked: Flag for tracking whether the output dimenions of
+            the states (after being preprocessed and transformed by the modules) have
+            been verified.
+    """
+
+    def __init__(
+        self,
+        module: nn.Module,
+        preprocessor: Preprocessor | None = None,
+        is_backward: bool = False,
+        reduction: str = "mean",
+    ):
+        super().__init__(module, preprocessor, is_backward)
+        reduction_fxns = {
+            "mean": torch.mean,
+            "sum": torch.sum,
+            "prod": torch.prod,
+        }
+        assert reduction in reduction_fxns
+        self.reduction_fxn = reduction_fxns[reduction]
+
     def expected_output_dim(self) -> int:
         return 1
 
+    def forward(self, input: States | torch.Tensor) -> torch.Tensor:
+        """Forward pass of the module.
+
+        Args:
+            input: The input to the module, as states or a tensor.
+
+        Returns the output of the module, as a tensor of shape (*batch_shape, output_dim).
+        """
+        if isinstance(input, States):
+            input = self.preprocessor(input)
+
+        out = self.module(input)
+
+        # Ensures estimator outputs are always scalar.
+        if out.shape[-1] != 1:
+            out = self.reduction_fxn(out, -1)
+
+        if not self._output_dim_is_checked:
+            self.check_output_dim(out)
+            self._output_dim_is_checked = True
+
+        return out
+
 
 class DiscretePolicyEstimator(GFNModule):
     r"""Container for forward and backward policy estimators for discrete environments.
@@ -290,6 +352,31 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:
 
 
 class ConditionalScalarEstimator(ConditionalDiscretePolicyEstimator):
+    r"""Class for conditionally estimating scalars such as LogZ.
+
+    Training a GFlowNet requires sometimes requires the estimation of precise scalar
+    values, such as the partition function of flows on the DAG. In the case of a
+    conditional GFN, the logZ or logF estimate is also conditional. This Estimator is
+    designed for those cases.
+
+    The function approximator used for `module` need not directly output a scalar. If
+    it does not, `reduction` will be used to aggregate the outputs of the module into
+    a single scalar.
+
+    Attributes:
+        preprocessor: Preprocessor object that transforms raw States objects to tensors
+            that can be used as input to the module. Optional, defaults to
+            `IdentityPreprocessor`.
+        module: The module to use. If the module is a Tabular module (from
+            `gfn.utils.modules`), then the environment preprocessor needs to be an
+            `EnumPreprocessor`.
+        preprocessor: Preprocessor from the environment.
+        reduction_fxn: the selected torch reduction operation.
+        _output_dim_is_checked: Flag for tracking whether the output dimensions of
+            the states (after being preprocessed and transformed by the modules) have
+            been verified.
+    """
+
     def __init__(
         self,
         state_module: nn.Module,
@@ -297,6 +384,7 @@ def __init__(
         final_module: nn.Module,
         preprocessor: Preprocessor | None = None,
         is_backward: bool = False,
+        reduction: str = "mean",
     ):
         super().__init__(
             state_module,
@@ -306,6 +394,13 @@ def __init__(
             preprocessor=preprocessor,
             is_backward=is_backward,
         )
+        reduction_fxns = {
+            "mean": torch.mean,
+            "sum": torch.sum,
+            "prod": torch.prod,
+        }
+        assert reduction in reduction_fxns
+        self.reduction_fxn = reduction_fxns[reduction]
 
     def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:
         """Forward pass of the module.
@@ -318,6 +413,10 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:
         """
         out = self._forward_trunk(states, conditioning)
 
+        # Ensures estimator outputs are always scalar.
+        if out.shape[-1] != 1:
+            out = self.reduction_fxn(out, -1)
+
         if not self._output_dim_is_checked:
             self.check_output_dim(out)
             self._output_dim_is_checked = True
@@ -333,13 +432,4 @@ def to_probability_distribution(
         module_output: torch.Tensor,
         **policy_kwargs: Any,
     ) -> Distribution:
-        """Transform the output of the module into a probability distribution.
-
-        Args:
-            states: The states to use.
-            module_output: The output of the module as a tensor of shape (*batch_shape, output_dim).
-            **policy_kwargs: Keyword arguments to modify the distribution.
-
-        Returns a distribution object.
-        """
         raise NotImplementedError

From db13637daeb1c1e53ba3ded8111c33dcfb16ba3c Mon Sep 17 00:00:00 2001
From: Joseph Viviano <joseph@viviano.ca>
Date: Wed, 13 Nov 2024 15:41:14 -0500
Subject: [PATCH 2/7] black

---
 src/gfn/gflownet/flow_matching.py | 4 +---
 src/gfn/samplers.py               | 6 +++++-
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py
index 38072080..4d2f2354 100644
--- a/src/gfn/gflownet/flow_matching.py
+++ b/src/gfn/gflownet/flow_matching.py
@@ -203,9 +203,7 @@ def loss(
         )
         return fm_loss + self.alpha * rm_loss
 
-    def to_training_samples(
-        self, trajectories: Trajectories
-    ) -> Union[
+    def to_training_samples(self, trajectories: Trajectories) -> Union[
         Tuple[DiscreteStates, DiscreteStates, torch.Tensor, torch.Tensor],
         Tuple[DiscreteStates, DiscreteStates, None, None],
         Tuple[States, States, torch.Tensor, torch.Tensor],
diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py
index eb224fbf..0df9449f 100644
--- a/src/gfn/samplers.py
+++ b/src/gfn/samplers.py
@@ -35,7 +35,11 @@ def sample_actions(
         save_estimator_outputs: bool = False,
         save_logprobs: bool = True,
         **policy_kwargs: Any,
-    ) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None,]:
+    ) -> Tuple[
+        Actions,
+        torch.Tensor | None,
+        torch.Tensor | None,
+    ]:
         """Samples actions from the given states.
 
         Args:

From be1087cc5d8cd6ac33528e51bd1f11c860f4beca Mon Sep 17 00:00:00 2001
From: Joseph Viviano <joseph@viviano.ca>
Date: Thu, 14 Nov 2024 17:44:23 -0500
Subject: [PATCH 3/7] updated docstrings

---
 src/gfn/modules.py | 72 ++++++++++++++++++++++++++++++++++------------
 1 file changed, 53 insertions(+), 19 deletions(-)

diff --git a/src/gfn/modules.py b/src/gfn/modules.py
index 3dfdc89b..21fe3744 100644
--- a/src/gfn/modules.py
+++ b/src/gfn/modules.py
@@ -10,6 +10,13 @@
 from gfn.utils.distributions import UnsqueezedCategorical
 
 
+REDUCTION_FXNS = {
+    "mean": torch.mean,
+    "sum": torch.sum,
+    "prod": torch.prod,
+}
+
+
 class GFNModule(ABC, nn.Module):
     r"""Base class for modules mapping states distributions.
 
@@ -41,9 +48,11 @@ class GFNModule(ABC, nn.Module):
             `gfn.utils.modules`), then the environment preprocessor needs to be an
             `EnumPreprocessor`.
         preprocessor: Preprocessor from the environment.
-        _output_dim_is_checked: Flag for tracking whether the output dimenions of
+        _output_dim_is_checked: Flag for tracking whether the output dimensions of
             the states (after being preprocessed and transformed by the modules) have
             been verified.
+        _is_backward: Flag for tracking whether this estimator is used for predicting
+            probability distributions over parents.
     """
 
     def __init__(
@@ -52,7 +61,7 @@ def __init__(
         preprocessor: Preprocessor | None = None,
         is_backward: bool = False,
     ) -> None:
-        """Initalize the FunctionEstimator with an environment and a module.
+        """Initialize the GFNModule with nn.Module and a preprocessor.
         Args:
             module: The module to use. If the module is a Tabular module (from
                 `gfn.utils.modules`), then the environment preprocessor needs to be an
@@ -152,9 +161,12 @@ class ScalarEstimator(GFNModule):
             `gfn.utils.modules`), then the environment preprocessor needs to be an
             `EnumPreprocessor`.
         preprocessor: Preprocessor from the environment.
-        _output_dim_is_checked: Flag for tracking whether the output dimenions of
+        _output_dim_is_checked: Flag for tracking whether the output dimensions of
             the states (after being preprocessed and transformed by the modules) have
             been verified.
+        _is_backward: Flag for tracking whether this estimator is used for predicting
+            probability distributions over parents.
+        reduction_fxn: The torch reduction function selected by `reduction`.
     """
 
     def __init__(
@@ -164,14 +176,22 @@ def __init__(
         is_backward: bool = False,
         reduction: str = "mean",
     ):
+        """Initialize the GFNModule with a scalar output.
+        Args:
+            module: The module to use. If the module is a Tabular module (from
+                `gfn.utils.modules`), then the environment preprocessor needs to be an
+                `EnumPreprocessor`.
+            preprocessor: Preprocessor object.
+            is_backward: Flags estimators of probability distributions over parents.
+            reduction: str name of one of the REDUCTION_FXNS keys ("mean", "sum",
+                or "prod"); the chosen function reduces the last dimension of the
+                module output whenever that dimension is not of size 1.
+        """
         super().__init__(module, preprocessor, is_backward)
-        reduction_fxns = {
-            "mean": torch.mean,
-            "sum": torch.sum,
-            "prod": torch.prod,
-        }
-        assert reduction in reduction_fxns
-        self.reduction_fxn = reduction_fxns[reduction]
+        assert reduction in REDUCTION_FXNS, "reduction function not one of {}".format(
+            REDUCTION_FXNS.keys()
+        )
+        self.reduction_fxn = REDUCTION_FXNS[reduction]
 
     def expected_output_dim(self) -> int:
         return 1
@@ -359,8 +379,8 @@ class ConditionalScalarEstimator(ConditionalDiscretePolicyEstimator):
     conditional GFN, the logZ or logF estimate is also conditional. This Estimator is
     designed for those cases.
 
-    The function approximator used for `module` need not directly output a scalar. If
-    it does not, `reduction` will be used to aggregate the outputs of the module into
+    The function approximator used for `final_module` need not directly output a scalar.
+    If it does not, `reduction` will be used to aggregate the outputs of the module into
     a single scalar.
 
     Attributes:
@@ -375,6 +395,9 @@ class ConditionalScalarEstimator(ConditionalDiscretePolicyEstimator):
         _output_dim_is_checked: Flag for tracking whether the output dimensions of
             the states (after being preprocessed and transformed by the modules) have
             been verified.
+        _is_backward: Flag for tracking whether this estimator is used for
+            predicting probability distributions over parents, i.e. whether it
+            is a backward estimator.
     """
 
     def __init__(
@@ -386,6 +409,20 @@ def __init__(
         is_backward: bool = False,
         reduction: str = "mean",
     ):
+        """Initialize a conditional GFNModule with a scalar output.
+        Args:
+            state_module: The module to use for state representations. If the module is
+                a Tabular module (from `gfn.utils.modules`), then the environment
+                preprocessor needs to be an `EnumPreprocessor`.
+            conditioning_module: The module to use for conditioning representations.
+            final_module: The module to use for computing the final output.
+            preprocessor: Preprocessor object.
+            is_backward: Flags estimators of probability distributions over parents.
+            reduction: str name of one of the REDUCTION_FXNS keys ("mean", "sum",
+                or "prod"); the chosen function reduces the last dimension of the
+                trunk output whenever that dimension is not of size 1.
+        """
+
         super().__init__(
             state_module,
             conditioning_module,
@@ -394,13 +431,10 @@ def __init__(
             preprocessor=preprocessor,
             is_backward=is_backward,
         )
-        reduction_fxns = {
-            "mean": torch.mean,
-            "sum": torch.sum,
-            "prod": torch.prod,
-        }
-        assert reduction in reduction_fxns
-        self.reduction_fxn = reduction_fxns[reduction]
+        assert reduction in REDUCTION_FXNS, "reduction function not one of {}".format(
+            REDUCTION_FXNS.keys()
+        )
+        self.reduction_fxn = REDUCTION_FXNS[reduction]
 
     def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:
         """Forward pass of the module.

From 0adf8084f2a4f0e660658dafe1cbb4600629dfea Mon Sep 17 00:00:00 2001
From: Joseph Viviano <joseph@viviano.ca>
Date: Thu, 14 Nov 2024 17:54:53 -0500
Subject: [PATCH 4/7] isort

---
 src/gfn/modules.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/gfn/modules.py b/src/gfn/modules.py
index 21fe3744..5dd36d5a 100644
--- a/src/gfn/modules.py
+++ b/src/gfn/modules.py
@@ -9,7 +9,6 @@
 from gfn.states import DiscreteStates, States
 from gfn.utils.distributions import UnsqueezedCategorical
 
-
 REDUCTION_FXNS = {
     "mean": torch.mean,
     "sum": torch.sum,

From 98579e81804f1825978eaf3ef2a93faaeb219a40 Mon Sep 17 00:00:00 2001
From: Joseph Viviano <joseph@viviano.ca>
Date: Thu, 14 Nov 2024 18:16:31 -0500
Subject: [PATCH 5/7] isort/black

---
 src/gfn/samplers.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py
index 0df9449f..819620f0 100644
--- a/src/gfn/samplers.py
+++ b/src/gfn/samplers.py
@@ -35,11 +35,7 @@ def sample_actions(
         save_estimator_outputs: bool = False,
         save_logprobs: bool = True,
         **policy_kwargs: Any,
-    ) -> Tuple[
-        Actions,
-        torch.Tensor | None,
-        torch.Tensor | None,
-    ]:
+    ) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None]:
         """Samples actions from the given states.
 
         Args:

From 9f0140edcbcc214b8ddd0588dc1219ee6767aeb5 Mon Sep 17 00:00:00 2001
From: Joseph Viviano <joseph@viviano.ca>
Date: Thu, 14 Nov 2024 19:17:18 -0500
Subject: [PATCH 6/7] black

---
 src/gfn/gflownet/flow_matching.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py
index 4d2f2354..38072080 100644
--- a/src/gfn/gflownet/flow_matching.py
+++ b/src/gfn/gflownet/flow_matching.py
@@ -203,7 +203,9 @@ def loss(
         )
         return fm_loss + self.alpha * rm_loss
 
-    def to_training_samples(self, trajectories: Trajectories) -> Union[
+    def to_training_samples(
+        self, trajectories: Trajectories
+    ) -> Union[
         Tuple[DiscreteStates, DiscreteStates, torch.Tensor, torch.Tensor],
         Tuple[DiscreteStates, DiscreteStates, None, None],
         Tuple[States, States, torch.Tensor, torch.Tensor],

From 90e14e2bcfa037e6d1aeb091799c2fa1b9a7185d Mon Sep 17 00:00:00 2001
From: Joseph Viviano <joseph@viviano.ca>
Date: Thu, 14 Nov 2024 19:33:14 -0500
Subject: [PATCH 7/7] updated docstrings

---
 src/gfn/modules.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/gfn/modules.py b/src/gfn/modules.py
index 5dd36d5a..bf649837 100644
--- a/src/gfn/modules.py
+++ b/src/gfn/modules.py
@@ -142,7 +142,7 @@ def to_probability_distribution(
 
 
 class ScalarEstimator(GFNModule):
-    r"""Class for estimating scalars such as LogZ.
+    r"""Class for estimating scalars such as LogZ or state flow functions of DB/SubTB.
 
     Training a GFlowNet requires sometimes requires the estimation of precise scalar
     values, such as the partition function of flows on the DAG. This Estimator is
@@ -371,7 +371,7 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:
 
 
 class ConditionalScalarEstimator(ConditionalDiscretePolicyEstimator):
-    r"""Class for conditionally estimating scalars such as LogZ.
+    r"""Class for conditionally estimating scalars (LogZ, DB/SubTB state logF).
 
     Training a GFlowNet requires sometimes requires the estimation of precise scalar
     values, such as the partition function of flows on the DAG. In the case of a