From 1c4ec37d97cea4ae3e6d8431083725e167cc82f3 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 13 Nov 2024 15:38:29 -0500 Subject: [PATCH 1/7] Scalar estimators allow for the reduction over many output values (i.e., the output of the nn.Module does not need to be a scalar, because the Estimator will apply a reduction to the final output if required). --- src/gfn/modules.py | 108 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 99 insertions(+), 9 deletions(-) diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 4351b462..3dfdc89b 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -134,9 +134,71 @@ def to_probability_distribution( class ScalarEstimator(GFNModule): + r"""Class for estimating scalars such as LogZ. + + Training a GFlowNet requires sometimes requires the estimation of precise scalar + values, such as the partition function of flows on the DAG. This Estimator is + designed for those cases. + + The function approximator used for `module` need not directly output a scalar. If + it does not, `reduction` will be used to aggregate the outputs of the module into + a single scalar. + + Attributes: + preprocessor: Preprocessor object that transforms raw States objects to tensors + that can be used as input to the module. Optional, defaults to + `IdentityPreprocessor`. + module: The module to use. If the module is a Tabular module (from + `gfn.utils.modules`), then the environment preprocessor needs to be an + `EnumPreprocessor`. + preprocessor: Preprocessor from the environment. + _output_dim_is_checked: Flag for tracking whether the output dimenions of + the states (after being preprocessed and transformed by the modules) have + been verified. + """ + + def __init__( + self, + module: nn.Module, + preprocessor: Preprocessor | None = None, + is_backward: bool = False, + reduction: str = "mean", + ): + super().__init__(module, preprocessor, is_backward) + reduction_fxns = { + "mean": torch.mean, + "sum": torch.sum, + "prod": torch.prod, + } + assert reduction in reduction_fxns + self.reduction_fxn = reduction_fxns[reduction] + def expected_output_dim(self) -> int: return 1 + def forward(self, input: States | torch.Tensor) -> torch.Tensor: + """Forward pass of the module. + + Args: + input: The input to the module, as states or a tensor. + + Returns the output of the module, as a tensor of shape (*batch_shape, output_dim). + """ + if isinstance(input, States): + input = self.preprocessor(input) + + out = self.module(input) + + # Ensures estimator outputs are always scalar. + if out.shape[-1] != 1: + out = self.reduction_fxn(out, -1) + + if not self._output_dim_is_checked: + self.check_output_dim(out) + self._output_dim_is_checked = True + + return out + class DiscretePolicyEstimator(GFNModule): r"""Container for forward and backward policy estimators for discrete environments. @@ -290,6 +352,31 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor: class ConditionalScalarEstimator(ConditionalDiscretePolicyEstimator): + r"""Class for conditionally estimating scalars such as LogZ. + + Training a GFlowNet requires sometimes requires the estimation of precise scalar + values, such as the partition function of flows on the DAG. In the case of a + conditional GFN, the logZ or logF estimate is also conditional. This Estimator is + designed for those cases. + + The function approximator used for `module` need not directly output a scalar. If + it does not, `reduction` will be used to aggregate the outputs of the module into + a single scalar. + + Attributes: + preprocessor: Preprocessor object that transforms raw States objects to tensors + that can be used as input to the module. Optional, defaults to + `IdentityPreprocessor`. + module: The module to use. If the module is a Tabular module (from + `gfn.utils.modules`), then the environment preprocessor needs to be an + `EnumPreprocessor`. + preprocessor: Preprocessor from the environment. + reduction_fxn: the selected torch reduction operation. + _output_dim_is_checked: Flag for tracking whether the output dimensions of + the states (after being preprocessed and transformed by the modules) have + been verified. + """ + def __init__( self, state_module: nn.Module, @@ -297,6 +384,7 @@ def __init__( final_module: nn.Module, preprocessor: Preprocessor | None = None, is_backward: bool = False, + reduction: str = "mean", ): super().__init__( state_module, @@ -306,6 +394,13 @@ def __init__( preprocessor=preprocessor, is_backward=is_backward, ) + reduction_fxns = { + "mean": torch.mean, + "sum": torch.sum, + "prod": torch.prod, + } + assert reduction in reduction_fxns + self.reduction_fxn = reduction_fxns[reduction] def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor: """Forward pass of the module. @@ -318,6 +413,10 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor: """ out = self._forward_trunk(states, conditioning) + # Ensures estimator outputs are always scalar. + if out.shape[-1] != 1: + out = self.reduction_fxn(out, -1) + if not self._output_dim_is_checked: self.check_output_dim(out) self._output_dim_is_checked = True @@ -333,13 +432,4 @@ def to_probability_distribution( module_output: torch.Tensor, **policy_kwargs: Any, ) -> Distribution: - """Transform the output of the module into a probability distribution. - - Args: - states: The states to use. - module_output: The output of the module as a tensor of shape (*batch_shape, output_dim). - **policy_kwargs: Keyword arguments to modify the distribution. - - Returns a distribution object. - """ raise NotImplementedError From db13637daeb1c1e53ba3ded8111c33dcfb16ba3c Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 13 Nov 2024 15:41:14 -0500 Subject: [PATCH 2/7] black --- src/gfn/gflownet/flow_matching.py | 4 +--- src/gfn/samplers.py | 6 +++++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index 38072080..4d2f2354 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -203,9 +203,7 @@ def loss( ) return fm_loss + self.alpha * rm_loss - def to_training_samples( - self, trajectories: Trajectories - ) -> Union[ + def to_training_samples(self, trajectories: Trajectories) -> Union[ Tuple[DiscreteStates, DiscreteStates, torch.Tensor, torch.Tensor], Tuple[DiscreteStates, DiscreteStates, None, None], Tuple[States, States, torch.Tensor, torch.Tensor], diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index eb224fbf..0df9449f 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -35,7 +35,11 @@ def sample_actions( save_estimator_outputs: bool = False, save_logprobs: bool = True, **policy_kwargs: Any, - ) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None,]: + ) -> Tuple[ + Actions, + torch.Tensor | None, + torch.Tensor | None, + ]: """Samples actions from the given states. Args: From be1087cc5d8cd6ac33528e51bd1f11c860f4beca Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 14 Nov 2024 17:44:23 -0500 Subject: [PATCH 3/7] updated docstrings --- src/gfn/modules.py | 72 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 53 insertions(+), 19 deletions(-) diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 3dfdc89b..21fe3744 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -10,6 +10,13 @@ from gfn.utils.distributions import UnsqueezedCategorical +REDUCTION_FXNS = { + "mean": torch.mean, + "sum": torch.sum, + "prod": torch.prod, +} + + class GFNModule(ABC, nn.Module): r"""Base class for modules mapping states distributions. @@ -41,9 +48,11 @@ class GFNModule(ABC, nn.Module): `gfn.utils.modules`), then the environment preprocessor needs to be an `EnumPreprocessor`. preprocessor: Preprocessor from the environment. - _output_dim_is_checked: Flag for tracking whether the output dimenions of + _output_dim_is_checked: Flag for tracking whether the output dimensions of the states (after being preprocessed and transformed by the modules) have been verified. + _is_backward: Flag for tracking whether this estimator is used for predicting + probability distributions over parents. """ def __init__( @@ -52,7 +61,7 @@ def __init__( preprocessor: Preprocessor | None = None, is_backward: bool = False, ) -> None: - """Initalize the FunctionEstimator with an environment and a module. + """Initialize the GFNModule with nn.Module and a preprocessor. Args: module: The module to use. If the module is a Tabular module (from `gfn.utils.modules`), then the environment preprocessor needs to be an @@ -152,9 +161,12 @@ class ScalarEstimator(GFNModule): `gfn.utils.modules`), then the environment preprocessor needs to be an `EnumPreprocessor`. preprocessor: Preprocessor from the environment. - _output_dim_is_checked: Flag for tracking whether the output dimenions of + _output_dim_is_checked: Flag for tracking whether the output dimensions of the states (after being preprocessed and transformed by the modules) have been verified. + _is_backward: Flag for tracking whether this estimator is used for predicting + probability distributions over parents. + reduction_function: String denoting the """ def __init__( @@ -164,14 +176,22 @@ def __init__( is_backward: bool = False, reduction: str = "mean", ): + """Initialize the GFNModule with a scalar output. + Args: + module: The module to use. If the module is a Tabular module (from + `gfn.utils.modules`), then the environment preprocessor needs to be an + `EnumPreprocessor`. + preprocessor: Preprocessor object. + is_backward: Flags estimators of probability distributions over parents. + reduction: str name of the one of the REDUCTION_FXNS keys: {} + """.format( + list(REDUCTION_FXNS.keys()) + ) super().__init__(module, preprocessor, is_backward) - reduction_fxns = { - "mean": torch.mean, - "sum": torch.sum, - "prod": torch.prod, - } - assert reduction in reduction_fxns - self.reduction_fxn = reduction_fxns[reduction] + assert reduction in REDUCTION_FXNS, "reduction function not one of {}".format( + REDUCTION_FXNS.keys() + ) + self.reduction_fxn = REDUCTION_FXNS[reduction] def expected_output_dim(self) -> int: return 1 @@ -359,8 +379,8 @@ class ConditionalScalarEstimator(ConditionalDiscretePolicyEstimator): conditional GFN, the logZ or logF estimate is also conditional. This Estimator is designed for those cases. - The function approximator used for `module` need not directly output a scalar. If - it does not, `reduction` will be used to aggregate the outputs of the module into + The function approximator used for `final_module` need not directly output a scalar. + If it does not, `reduction` will be used to aggregate the outputs of the module into a single scalar. Attributes: @@ -375,6 +395,9 @@ class ConditionalScalarEstimator(ConditionalDiscretePolicyEstimator): _output_dim_is_checked: Flag for tracking whether the output dimensions of the states (after being preprocessed and transformed by the modules) have been verified. + _is_backward: Flag for tracking whether this estimator is used for predicting + probability distributions over parents. + reduction_function: String denoting the """ def __init__( @@ -386,6 +409,20 @@ def __init__( is_backward: bool = False, reduction: str = "mean", ): + """Initialize a conditional GFNModule with a scalar output. + Args: + state_module: The module to use for state representations. If the module is + a Tabular module (from `gfn.utils.modules`), then the environment + preprocessor needs to be an `EnumPreprocessor`. + conditioning_module: The module to use for conditioning representations. + final_module: The module to use for computing the final output. + preprocessor: Preprocessor object. + is_backward: Flags estimators of probability distributions over parents. + reduction: str name of the one of the REDUCTION_FXNS keys: {} + """.format( + list(REDUCTION_FXNS.keys()) + ) + super().__init__( state_module, conditioning_module, @@ -394,13 +431,10 @@ def __init__( preprocessor=preprocessor, is_backward=is_backward, ) - reduction_fxns = { - "mean": torch.mean, - "sum": torch.sum, - "prod": torch.prod, - } - assert reduction in reduction_fxns - self.reduction_fxn = reduction_fxns[reduction] + assert reduction in REDUCTION_FXNS, "reduction function not one of {}".format( + REDUCTION_FXNS.keys() + ) + self.reduction_fxn = REDUCTION_FXNS[reduction] def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor: """Forward pass of the module. From 0adf8084f2a4f0e660658dafe1cbb4600629dfea Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 14 Nov 2024 17:54:53 -0500 Subject: [PATCH 4/7] isort --- src/gfn/modules.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 21fe3744..5dd36d5a 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -9,7 +9,6 @@ from gfn.states import DiscreteStates, States from gfn.utils.distributions import UnsqueezedCategorical - REDUCTION_FXNS = { "mean": torch.mean, "sum": torch.sum, From 98579e81804f1825978eaf3ef2a93faaeb219a40 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 14 Nov 2024 18:16:31 -0500 Subject: [PATCH 5/7] isort/black --- src/gfn/samplers.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 0df9449f..819620f0 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -35,11 +35,7 @@ def sample_actions( save_estimator_outputs: bool = False, save_logprobs: bool = True, **policy_kwargs: Any, - ) -> Tuple[ - Actions, - torch.Tensor | None, - torch.Tensor | None, - ]: + ) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None]: """Samples actions from the given states. Args: From 9f0140edcbcc214b8ddd0588dc1219ee6767aeb5 Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 14 Nov 2024 19:17:18 -0500 Subject: [PATCH 6/7] black --- src/gfn/gflownet/flow_matching.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index 4d2f2354..38072080 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -203,7 +203,9 @@ def loss( ) return fm_loss + self.alpha * rm_loss - def to_training_samples(self, trajectories: Trajectories) -> Union[ + def to_training_samples( + self, trajectories: Trajectories + ) -> Union[ Tuple[DiscreteStates, DiscreteStates, torch.Tensor, torch.Tensor], Tuple[DiscreteStates, DiscreteStates, None, None], Tuple[States, States, torch.Tensor, torch.Tensor], From 90e14e2bcfa037e6d1aeb091799c2fa1b9a7185d Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Thu, 14 Nov 2024 19:33:14 -0500 Subject: [PATCH 7/7] updated docstrings --- src/gfn/modules.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 5dd36d5a..bf649837 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -142,7 +142,7 @@ def to_probability_distribution( class ScalarEstimator(GFNModule): - r"""Class for estimating scalars such as LogZ. + r"""Class for estimating scalars such as LogZ or state flow functions of DB/SubTB. Training a GFlowNet requires sometimes requires the estimation of precise scalar values, such as the partition function of flows on the DAG. This Estimator is @@ -371,7 +371,7 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor: class ConditionalScalarEstimator(ConditionalDiscretePolicyEstimator): - r"""Class for conditionally estimating scalars such as LogZ. + r"""Class for conditionally estimating scalars (LogZ, DB/SubTB state logF). Training a GFlowNet requires sometimes requires the estimation of precise scalar values, such as the partition function of flows on the DAG. In the case of a