Scalar estimators allow for the reduction over many output values (i.… #215
@@ -10,6 +10,13 @@
 from gfn.utils.distributions import UnsqueezedCategorical


+REDUCTION_FXNS = {
+    "mean": torch.mean,
+    "sum": torch.sum,
+    "prod": torch.prod,
+}
+
+
 class GFNModule(ABC, nn.Module):
     r"""Base class for modules mapping states distributions.
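For reference, here is a quick sketch (values invented) of what each `REDUCTION_FXNS` entry does: every reduction collapses the last dimension of the module's output to a single value per batch element.

    import torch

    REDUCTION_FXNS = {"mean": torch.mean, "sum": torch.sum, "prod": torch.prod}

    out = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # (batch=2, output_dim=3)
    print(REDUCTION_FXNS["mean"](out, -1))  # tensor([2., 5.])
    print(REDUCTION_FXNS["sum"](out, -1))   # tensor([ 6., 15.])
    print(REDUCTION_FXNS["prod"](out, -1))  # tensor([  6., 120.])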
@@ -41,9 +48,11 @@ class GFNModule(ABC, nn.Module):
         `gfn.utils.modules`), then the environment preprocessor needs to be an
         `EnumPreprocessor`.
         preprocessor: Preprocessor from the environment.
-        _output_dim_is_checked: Flag for tracking whether the output dimenions of
+        _output_dim_is_checked: Flag for tracking whether the output dimensions of
             the states (after being preprocessed and transformed by the modules) have
             been verified.
+        _is_backward: Flag for tracking whether this estimator is used for predicting
+            probability distributions over parents.
     """

     def __init__(
@@ -52,7 +61,7 @@ def __init__(
         preprocessor: Preprocessor | None = None,
         is_backward: bool = False,
     ) -> None:
-        """Initalize the FunctionEstimator with an environment and a module.
+        """Initialize the GFNModule with an nn.Module and a preprocessor.
         Args:
             module: The module to use. If the module is a Tabular module (from
                 `gfn.utils.modules`), then the environment preprocessor needs to be an
@@ -134,9 +143,82 @@ def to_probability_distribution(



+class ScalarEstimator(GFNModule):
+    r"""Class for estimating scalars such as LogZ.
+
+    Training a GFlowNet sometimes requires the estimation of precise scalar
+    values, such as the partition function of flows on the DAG. This Estimator is
+    designed for those cases.
+
+    The function approximator used for `module` need not directly output a scalar. If
+    it does not, `reduction` will be used to aggregate the outputs of the module into
+    a single scalar.
+
+    Attributes:
+        preprocessor: Preprocessor object that transforms raw States objects to tensors
+            that can be used as input to the module. Optional, defaults to
+            `IdentityPreprocessor`.
+        module: The module to use. If the module is a Tabular module (from
+            `gfn.utils.modules`), then the environment preprocessor needs to be an
+            `EnumPreprocessor`.
+        _output_dim_is_checked: Flag for tracking whether the output dimensions of
+            the states (after being preprocessed and transformed by the modules) have
+            been verified.
+        _is_backward: Flag for tracking whether this estimator is used for predicting
+            probability distributions over parents.
+        reduction_fxn: The torch reduction operation selected via `reduction`.
+    """
+
+    def __init__(
+        self,
+        module: nn.Module,
+        preprocessor: Preprocessor | None = None,
+        is_backward: bool = False,
+        reduction: str = "mean",
+    ):
+        """Initialize the GFNModule with a scalar output.
+        Args:
+            module: The module to use. If the module is a Tabular module (from
+                `gfn.utils.modules`), then the environment preprocessor needs to be an
+                `EnumPreprocessor`.
+            preprocessor: Preprocessor object.
+            is_backward: Flags estimators of probability distributions over parents.
+            reduction: str name of one of the REDUCTION_FXNS keys: {}
+        """.format(
+            list(REDUCTION_FXNS.keys())
+        )
+        super().__init__(module, preprocessor, is_backward)
+        assert reduction in REDUCTION_FXNS, "reduction function not one of {}".format(
+            REDUCTION_FXNS.keys()
+        )
+        self.reduction_fxn = REDUCTION_FXNS[reduction]
+
+    def expected_output_dim(self) -> int:
+        return 1
+
+    def forward(self, input: States | torch.Tensor) -> torch.Tensor:

Review thread on this line:

- In which case is the input a raw `torch.Tensor`?
- Yes, I was looking at this and not entirely sure. It might be in the case of conditioning, where we currently don't have any sort of container; conditioning is done with a raw tensor.
- Oh, it should be conditioning (e.g., conditional log Z(c)).
- Well, of note, ScalarEstimators are used for more than just logZ, but in this case I see it like this: from an optimization POV, sometimes having logZ only be estimated by a single parameter can cause problems (i.e., the gradients push the number around a lot), so using a neural network helps. I agree we could make it clearer though -- I am open to suggestions.
"""Forward pass of the module. | ||
|
||
Args: | ||
input: The input to the module, as states or a tensor. | ||
|
||
Returns the output of the module, as a tensor of shape (*batch_shape, output_dim). | ||
""" | ||
if isinstance(input, States): | ||
input = self.preprocessor(input) | ||
|
||
out = self.module(input) | ||
|
||
# Ensures estimator outputs are always scalar. | ||
if out.shape[-1] != 1: | ||
out = self.reduction_fxn(out, -1) | ||
|
||
if not self._output_dim_is_checked: | ||
self.check_output_dim(out) | ||
self._output_dim_is_checked = True | ||
|
||
return out | ||
|
||
|
||
 class DiscretePolicyEstimator(GFNModule):
     r"""Container for forward and backward policy estimators for discrete environments.
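As a runnable sketch of the reduction gate in `ScalarEstimator.forward` above (the toy module, its sizes, and the input shape are invented for illustration, not taken from the PR):

    import torch
    from torch import nn

    # Toy function approximator whose final dimension is 8 rather than 1.
    net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 8))

    x = torch.randn(32, 4)  # e.g., a raw conditioning tensor
    out = net(x)            # shape (32, 8)

    # Mirrors the branch in ScalarEstimator.forward: reduce only when the
    # last dimension is not already 1.
    if out.shape[-1] != 1:
        out = torch.mean(out, -1)  # REDUCTION_FXNS["mean"]
    print(out.shape)  # torch.Size([32])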
@@ -290,14 +372,57 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:



+class ConditionalScalarEstimator(ConditionalDiscretePolicyEstimator):
+    r"""Class for conditionally estimating scalars such as LogZ.
+
+    Training a GFlowNet sometimes requires the estimation of precise scalar
+    values, such as the partition function of flows on the DAG. In the case of a
+    conditional GFN, the logZ or logF estimate is also conditional. This Estimator is
+    designed for those cases.
+
+    The function approximator used for `final_module` need not directly output a scalar.
+    If it does not, `reduction` will be used to aggregate the outputs of the module into
+    a single scalar.
+
+    Attributes:
+        preprocessor: Preprocessor object that transforms raw States objects to tensors
+            that can be used as input to the module. Optional, defaults to
+            `IdentityPreprocessor`.
+        module: The module to use. If the module is a Tabular module (from
+            `gfn.utils.modules`), then the environment preprocessor needs to be an
+            `EnumPreprocessor`.
+        reduction_fxn: the selected torch reduction operation.
+        _output_dim_is_checked: Flag for tracking whether the output dimensions of
+            the states (after being preprocessed and transformed by the modules) have
+            been verified.
+        _is_backward: Flag for tracking whether this estimator is used for predicting
+            probability distributions over parents.
+    """
+
+    def __init__(
+        self,
+        state_module: nn.Module,
+        conditioning_module: nn.Module,
+        final_module: nn.Module,
+        preprocessor: Preprocessor | None = None,
+        is_backward: bool = False,
+        reduction: str = "mean",
+    ):
+        """Initialize a conditional GFNModule with a scalar output.
+        Args:
+            state_module: The module to use for state representations. If the module is
+                a Tabular module (from `gfn.utils.modules`), then the environment
+                preprocessor needs to be an `EnumPreprocessor`.
+            conditioning_module: The module to use for conditioning representations.
+            final_module: The module to use for computing the final output.
+            preprocessor: Preprocessor object.
+            is_backward: Flags estimators of probability distributions over parents.
+            reduction: str name of one of the REDUCTION_FXNS keys: {}
+        """.format(
+            list(REDUCTION_FXNS.keys())
+        )
+
+        super().__init__(
+            state_module,
+            conditioning_module,
@@ -306,6 +431,10 @@ def __init__(
             preprocessor=preprocessor,
             is_backward=is_backward,
         )
+        assert reduction in REDUCTION_FXNS, "reduction function not one of {}".format(
+            REDUCTION_FXNS.keys()
+        )
+        self.reduction_fxn = REDUCTION_FXNS[reduction]

     def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:
         """Forward pass of the module.
@@ -318,6 +447,10 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:
         """
         out = self._forward_trunk(states, conditioning)

+        # Ensures estimator outputs are always scalar.
+        if out.shape[-1] != 1:
+            out = self.reduction_fxn(out, -1)
+
         if not self._output_dim_is_checked:
             self.check_output_dim(out)
             self._output_dim_is_checked = True
@@ -333,13 +466,4 @@ def to_probability_distribution(
         module_output: torch.Tensor,
         **policy_kwargs: Any,
     ) -> Distribution:
-        """Transform the output of the module into a probability distribution.
-
-        Args:
-            states: The states to use.
-            module_output: The output of the module as a tensor of shape (*batch_shape, output_dim).
-            **policy_kwargs: Keyword arguments to modify the distribution.
-
-        Returns a distribution object.
-        """
         raise NotImplementedError
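To make the conditional variant concrete, a hypothetical construction sketch (module sizes are invented, the import path is assumed, and the idea that the trunk combines state and conditioning embeddings before `final_module` follows the design of `ConditionalDiscretePolicyEstimator`):

    from torch import nn

    from gfn.modules import ConditionalScalarEstimator  # import path assumed

    state_net = nn.Sequential(nn.Linear(4, 16), nn.ReLU())  # embeds states
    cond_net = nn.Sequential(nn.Linear(2, 16), nn.ReLU())   # embeds conditioning
    final_net = nn.Linear(32, 1)  # output_dim is already 1, so no reduction fires

    estimator = ConditionalScalarEstimator(
        state_module=state_net,
        conditioning_module=cond_net,
        final_module=final_net,
        reduction="mean",  # only used if final_module's last dimension != 1
    )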
@@ -35,7 +35,11 @@ def sample_actions(
         save_estimator_outputs: bool = False,
         save_logprobs: bool = True,
         **policy_kwargs: Any,
-    ) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None,]:
+    ) -> Tuple[
+        Actions,
+        torch.Tensor | None,
+        torch.Tensor | None,
+    ]:
         """Samples actions from the given states.

         Args:

Review comment on this hunk:

- Removing the last `,`
Review discussion:

- Note that `logZ` for unconditional TB is usually modeled with a single learnable parameter (`nn.Parameter`). Should we consider modifying `ScalarEstimator` to support this kind of behavior?
- The GFNs themselves support this directly (you do not need to pass an estimator at all; instead you just pass a float for `Z`).
- This comment is because of "such as LogZ" in the docstring!
- I'm not entirely sure what would be most clear here, but I'm open to suggestions.
- Why not just state flow functions of DB/SubTB?
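To make the trade-off in this thread concrete, a minimal sketch (sizes invented) of the two ways `logZ` can be modeled:

    import torch
    from torch import nn

    # Unconditional TB: logZ as a single learnable parameter.
    log_z = nn.Parameter(torch.zeros(1))

    # Conditional TB: logZ depends on the conditioning c, so a small network
    # (e.g., wrapped in a ScalarEstimator) replaces the lone parameter.
    log_z_net = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 1))
    c = torch.randn(5, 3)   # five conditioning vectors
    log_z_c = log_z_net(c)  # one logZ estimate per conditioning vector
    print(log_z_c.shape)    # torch.Size([5, 1])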