diff --git a/src/gfn/modules.py b/src/gfn/modules.py
index 4351b462..3dfdc89b 100644
--- a/src/gfn/modules.py
+++ b/src/gfn/modules.py
@@ -134,9 +134,71 @@ def to_probability_distribution(
 
 
 class ScalarEstimator(GFNModule):
+    r"""Class for estimating scalars such as LogZ.
+
+    Training a GFlowNet sometimes requires the estimation of precise scalar
+    values, such as the partition function of flows on the DAG. This estimator is
+    designed for those cases.
+
+    The function approximator used for `module` need not directly output a scalar. If
+    it does not, `reduction` will be used to aggregate the outputs of the module into
+    a single scalar.
+
+    Attributes:
+        preprocessor: Preprocessor object that transforms raw States objects to tensors
+            that can be used as input to the module. Optional, defaults to
+            `IdentityPreprocessor`.
+        module: The module to use. If the module is a Tabular module (from
+            `gfn.utils.modules`), then the environment preprocessor needs to be an
+            `EnumPreprocessor`.
+        reduction_fxn: The torch reduction function selected via `reduction`.
+        _output_dim_is_checked: Flag for tracking whether the output dimensions of
+            the states (after being preprocessed and transformed by the modules) have
+            been verified.
+    """
+
+    def __init__(
+        self,
+        module: nn.Module,
+        preprocessor: Preprocessor | None = None,
+        is_backward: bool = False,
+        reduction: str = "mean",
+    ):
+        super().__init__(module, preprocessor, is_backward)
+        reduction_fxns = {
+            "mean": torch.mean,
+            "sum": torch.sum,
+            "prod": torch.prod,
+        }
+        assert reduction in reduction_fxns
+        self.reduction_fxn = reduction_fxns[reduction]
+
     def expected_output_dim(self) -> int:
         return 1
 
+    def forward(self, input: States | torch.Tensor) -> torch.Tensor:
+        """Forward pass of the module.
+
+        Args:
+            input: The input to the module, as states or a tensor.
+
+        Returns the output of the module, as a tensor of shape (*batch_shape, 1).
+        """
+        if isinstance(input, States):
+            input = self.preprocessor(input)
+
+        out = self.module(input)
+
+        # Ensures estimator outputs are always scalar (keepdim retains the last dim).
+        if out.shape[-1] != 1:
+            out = self.reduction_fxn(out, -1, keepdim=True)
+
+        if not self._output_dim_is_checked:
+            self.check_output_dim(out)
+            self._output_dim_is_checked = True
+
+        return out
+
 
 class DiscretePolicyEstimator(GFNModule):
     r"""Container for forward and backward policy estimators for discrete environments.
@@ -290,6 +352,31 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:
 
 
 class ConditionalScalarEstimator(ConditionalDiscretePolicyEstimator):
+    r"""Class for conditionally estimating scalars such as LogZ.
+
+    Training a GFlowNet sometimes requires the estimation of precise scalar
+    values, such as the partition function of flows on the DAG. In the case of a
+    conditional GFN, the logZ or logF estimate is also conditional. This estimator is
+    designed for those cases.
+
+    The function approximator used for `module` need not directly output a scalar. If
+    it does not, `reduction` will be used to aggregate the outputs of the module into
+    a single scalar.
+
+    Attributes:
+        preprocessor: Preprocessor object that transforms raw States objects to tensors
+            that can be used as input to the module. Optional, defaults to
+            `IdentityPreprocessor`.
+        module: The module to use. If the module is a Tabular module (from
+            `gfn.utils.modules`), then the environment preprocessor needs to be an
+            `EnumPreprocessor`.
+        reduction_fxn: The torch reduction function (e.g. `torch.mean`) selected
+            via the `reduction` argument.
+        _output_dim_is_checked: Flag for tracking whether the output dimensions of
+            the states (after being preprocessed and transformed by the modules) have
+            been verified.
+    """
+
     def __init__(
         self,
         state_module: nn.Module,
@@ -297,6 +384,7 @@ def __init__(
         final_module: nn.Module,
         preprocessor: Preprocessor | None = None,
         is_backward: bool = False,
+        reduction: str = "mean",
    ):
         super().__init__(
             state_module,
@@ -306,6 +394,13 @@ def __init__(
             preprocessor=preprocessor,
             is_backward=is_backward,
         )
+        reduction_fxns = {
+            "mean": torch.mean,
+            "sum": torch.sum,
+            "prod": torch.prod,
+        }
+        assert reduction in reduction_fxns
+        self.reduction_fxn = reduction_fxns[reduction]
 
     def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:
         """Forward pass of the module.
@@ -318,6 +413,10 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:
         """
         out = self._forward_trunk(states, conditioning)
 
+        # Ensures estimator outputs are always scalar (keepdim retains the last dim).
+        if out.shape[-1] != 1:
+            out = self.reduction_fxn(out, -1, keepdim=True)
+
         if not self._output_dim_is_checked:
             self.check_output_dim(out)
             self._output_dim_is_checked = True
@@ -333,13 +432,4 @@ def to_probability_distribution(
         module_output: torch.Tensor,
         **policy_kwargs: Any,
     ) -> Distribution:
-        """Transform the output of the module into a probability distribution.
-
-        Args:
-            states: The states to use.
-            module_output: The output of the module as a tensor of shape (*batch_shape, output_dim).
-            **policy_kwargs: Keyword arguments to modify the distribution.
-
-        Returns a distribution object.
-        """
         raise NotImplementedError
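
Usage sketch for the new ScalarEstimator: a minimal example, assuming ScalarEstimator is importable from gfn.modules, that GFNModule subclasses nn.Module (forward is defined above) so the estimator is callable, and that the default preprocessor is IdentityPreprocessor as documented; the nn.Linear module, dimensions, and variable names here are illustrative only.

    import torch
    from torch import nn

    from gfn.modules import ScalarEstimator

    # Toy function approximator: maps 4-dim preprocessed states to 8 features.
    # Since the output dim is not 1, reduction="mean" aggregates it to a scalar.
    estimator = ScalarEstimator(nn.Linear(4, 8), reduction="mean")

    # forward() accepts raw States (run through the preprocessor) or tensors.
    dummy_inputs = torch.randn(16, 4)  # a batch of 16 preprocessed states
    out = estimator(dummy_inputs)
    print(out.shape)  # torch.Size([16, 1]): one scalar per batch element

ConditionalScalarEstimator applies the same reduction, but its forward takes a (states, conditioning) pair, which is routed through the inherited _forward_trunk before the scalar aggregation.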