diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 4351b462..bf649837 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -9,6 +9,12 @@ from gfn.states import DiscreteStates, States from gfn.utils.distributions import UnsqueezedCategorical +REDUCTION_FXNS = { + "mean": torch.mean, + "sum": torch.sum, + "prod": torch.prod, +} + class GFNModule(ABC, nn.Module): r"""Base class for modules mapping states distributions. @@ -41,9 +47,11 @@ class GFNModule(ABC, nn.Module): `gfn.utils.modules`), then the environment preprocessor needs to be an `EnumPreprocessor`. preprocessor: Preprocessor from the environment. - _output_dim_is_checked: Flag for tracking whether the output dimenions of + _output_dim_is_checked: Flag for tracking whether the output dimensions of the states (after being preprocessed and transformed by the modules) have been verified. + _is_backward: Flag for tracking whether this estimator is used for predicting + probability distributions over parents. """ def __init__( @@ -52,7 +60,7 @@ def __init__( preprocessor: Preprocessor | None = None, is_backward: bool = False, ) -> None: - """Initalize the FunctionEstimator with an environment and a module. + """Initialize the GFNModule with an nn.Module and a preprocessor. Args: module: The module to use. If the module is a Tabular module (from `gfn.utils.modules`), then the environment preprocessor needs to be an @@ -134,9 +142,82 @@ def to_probability_distribution( class ScalarEstimator(GFNModule): + r"""Class for estimating scalars such as LogZ or state flow functions of DB/SubTB. + + Training a GFlowNet sometimes requires the estimation of precise scalar + values, such as the partition function of flows on the DAG. This Estimator is + designed for those cases. + + The function approximator used for `module` need not directly output a scalar. If + it does not, `reduction` will be used to aggregate the outputs of the module into + a single scalar. 
+ + Attributes: + preprocessor: Preprocessor object that transforms raw States objects to tensors + that can be used as input to the module. Optional, defaults to + `IdentityPreprocessor`. + module: The module to use. If the module is a Tabular module (from + `gfn.utils.modules`), then the environment preprocessor needs to be an + `EnumPreprocessor`. + preprocessor: Preprocessor from the environment. + _output_dim_is_checked: Flag for tracking whether the output dimensions of + the states (after being preprocessed and transformed by the modules) have + been verified. + _is_backward: Flag for tracking whether this estimator is used for predicting + probability distributions over parents. + reduction_fxn: The selected torch reduction operation. + """ + + def __init__( + self, + module: nn.Module, + preprocessor: Preprocessor | None = None, + is_backward: bool = False, + reduction: str = "mean", + ): + """Initialize the GFNModule with a scalar output. + Args: + module: The module to use. If the module is a Tabular module (from + `gfn.utils.modules`), then the environment preprocessor needs to be an + `EnumPreprocessor`. + preprocessor: Preprocessor object. + is_backward: Flags estimators of probability distributions over parents. + reduction: str name of one of the REDUCTION_FXNS keys: {} + """.format( + list(REDUCTION_FXNS.keys()) + ) + super().__init__(module, preprocessor, is_backward) + assert reduction in REDUCTION_FXNS, "reduction function not one of {}".format( + REDUCTION_FXNS.keys() + ) + self.reduction_fxn = REDUCTION_FXNS[reduction] + def expected_output_dim(self) -> int: return 1 + def forward(self, input: States | torch.Tensor) -> torch.Tensor: + """Forward pass of the module. + + Args: + input: The input to the module, as states or a tensor. + + Returns the output of the module, as a tensor of shape (*batch_shape, output_dim). 
+ """ + if isinstance(input, States): + input = self.preprocessor(input) + + out = self.module(input) + + # Ensures estimator outputs are always scalar. + if out.shape[-1] != 1: + out = self.reduction_fxn(out, -1) + + if not self._output_dim_is_checked: + self.check_output_dim(out) + self._output_dim_is_checked = True + + return out + class DiscretePolicyEstimator(GFNModule): r"""Container for forward and backward policy estimators for discrete environments. @@ -290,6 +371,34 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor: class ConditionalScalarEstimator(ConditionalDiscretePolicyEstimator): + r"""Class for conditionally estimating scalars (LogZ, DB/SubTB state logF). + + Training a GFlowNet sometimes requires the estimation of precise scalar + values, such as the partition function of flows on the DAG. In the case of a + conditional GFN, the logZ or logF estimate is also conditional. This Estimator is + designed for those cases. + + The function approximator used for `final_module` need not directly output a scalar. + If it does not, `reduction` will be used to aggregate the outputs of the module into + a single scalar. + + Attributes: + preprocessor: Preprocessor object that transforms raw States objects to tensors + that can be used as input to the module. Optional, defaults to + `IdentityPreprocessor`. + module: The module to use. If the module is a Tabular module (from + `gfn.utils.modules`), then the environment preprocessor needs to be an + `EnumPreprocessor`. + preprocessor: Preprocessor from the environment. + reduction_fxn: the selected torch reduction operation. + _output_dim_is_checked: Flag for tracking whether the output dimensions of + the states (after being preprocessed and transformed by the modules) have + been verified. + _is_backward: Flag for tracking whether this estimator is used for predicting + probability distributions over parents. 
+ Note: `reduction_fxn` is looked up in REDUCTION_FXNS using the `reduction` argument. + """ + def __init__( self, state_module: nn.Module, @@ -297,7 +406,22 @@ def __init__( final_module: nn.Module, preprocessor: Preprocessor | None = None, is_backward: bool = False, + reduction: str = "mean", ): + """Initialize a conditional GFNModule with a scalar output. + Args: + state_module: The module to use for state representations. If the module is + a Tabular module (from `gfn.utils.modules`), then the environment + preprocessor needs to be an `EnumPreprocessor`. + conditioning_module: The module to use for conditioning representations. + final_module: The module to use for computing the final output. + preprocessor: Preprocessor object. + is_backward: Flags estimators of probability distributions over parents. + reduction: str name of one of the REDUCTION_FXNS keys: {} + """.format( + list(REDUCTION_FXNS.keys()) + ) + super().__init__( state_module, conditioning_module, @@ -306,6 +430,10 @@ def __init__( preprocessor=preprocessor, is_backward=is_backward, ) + assert reduction in REDUCTION_FXNS, "reduction function not one of {}".format( + REDUCTION_FXNS.keys() + ) + self.reduction_fxn = REDUCTION_FXNS[reduction] def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor: """Forward pass of the module. @@ -318,6 +446,10 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor: """ out = self._forward_trunk(states, conditioning) + # Ensures estimator outputs are always scalar. + if out.shape[-1] != 1: + out = self.reduction_fxn(out, -1) + if not self._output_dim_is_checked: self.check_output_dim(out) self._output_dim_is_checked = True @@ -333,13 +465,4 @@ def to_probability_distribution( module_output: torch.Tensor, **policy_kwargs: Any, ) -> Distribution: - """Transform the output of the module into a probability distribution. - - Args: - states: The states to use. - module_output: The output of the module as a tensor of shape (*batch_shape, output_dim). 
- **policy_kwargs: Keyword arguments to modify the distribution. - - Returns a distribution object. - """ raise NotImplementedError diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index eb224fbf..819620f0 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -35,7 +35,7 @@ def sample_actions( save_estimator_outputs: bool = False, save_logprobs: bool = True, **policy_kwargs: Any, - ) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None,]: + ) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None]: """Samples actions from the given states. Args: