Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scalar estimators allow for the reduction over many output values (i.… #215

Merged
merged 7 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/gfn/gflownet/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,7 @@ def loss(
)
return fm_loss + self.alpha * rm_loss

def to_training_samples(
self, trajectories: Trajectories
) -> Union[
def to_training_samples(self, trajectories: Trajectories) -> Union[
Tuple[DiscreteStates, DiscreteStates, torch.Tensor, torch.Tensor],
Tuple[DiscreteStates, DiscreteStates, None, None],
Tuple[States, States, torch.Tensor, torch.Tensor],
Expand Down
146 changes: 135 additions & 11 deletions src/gfn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
from gfn.utils.distributions import UnsqueezedCategorical


# Mapping from the reduction-name strings accepted by the scalar estimators
# (via their `reduction` constructor argument) to the torch functions that
# aggregate a module's multi-dimensional output into a single scalar.
REDUCTION_FXNS = {
    "mean": torch.mean,
    "sum": torch.sum,
    "prod": torch.prod,
}


class GFNModule(ABC, nn.Module):
r"""Base class for modules mapping states distributions.

Expand Down Expand Up @@ -41,9 +48,11 @@ class GFNModule(ABC, nn.Module):
`gfn.utils.modules`), then the environment preprocessor needs to be an
`EnumPreprocessor`.
preprocessor: Preprocessor from the environment.
_output_dim_is_checked: Flag for tracking whether the output dimenions of
_output_dim_is_checked: Flag for tracking whether the output dimensions of
the states (after being preprocessed and transformed by the modules) have
been verified.
_is_backward: Flag for tracking whether this estimator is used for predicting
probability distributions over parents.
"""

def __init__(
Expand All @@ -52,7 +61,7 @@ def __init__(
preprocessor: Preprocessor | None = None,
is_backward: bool = False,
) -> None:
"""Initalize the FunctionEstimator with an environment and a module.
"""Initialize the GFNModule with nn.Module and a preprocessor.
Args:
module: The module to use. If the module is a Tabular module (from
`gfn.utils.modules`), then the environment preprocessor needs to be an
Expand Down Expand Up @@ -134,9 +143,82 @@ def to_probability_distribution(


class ScalarEstimator(GFNModule):
    r"""Class for estimating scalars, such as LogZ or the state flow functions
    of DB/SubTB.

    Training a GFlowNet sometimes requires the estimation of precise scalar
    values, such as the partition function of flows on the DAG. This Estimator is
    designed for those cases.

    The function approximator used for `module` need not directly output a scalar. If
    it does not, `reduction` will be used to aggregate the outputs of the module into
    a single scalar.

    Attributes:
        module: The module to use. If the module is a Tabular module (from
            `gfn.utils.modules`), then the environment preprocessor needs to be an
            `EnumPreprocessor`.
        preprocessor: Preprocessor from the environment.
        reduction_fxn: The selected torch reduction operation (one of the values
            of `REDUCTION_FXNS`).
        _output_dim_is_checked: Flag for tracking whether the output dimensions of
            the states (after being preprocessed and transformed by the modules) have
            been verified.
        _is_backward: Flag for tracking whether this estimator is used for predicting
            probability distributions over parents.
    """

    def __init__(
        self,
        module: nn.Module,
        preprocessor: Preprocessor | None = None,
        is_backward: bool = False,
        reduction: str = "mean",
    ):
        # NOTE: this docstring is a plain literal on purpose. Applying
        # `.format(...)` to it would turn it into a discarded expression
        # statement and leave `__doc__` as None.
        """Initialize the GFNModule with a scalar output.

        Args:
            module: The module to use. If the module is a Tabular module (from
                `gfn.utils.modules`), then the environment preprocessor needs to be an
                `EnumPreprocessor`.
            preprocessor: Preprocessor object.
            is_backward: Flags estimators of probability distributions over parents.
            reduction: Name of one of the REDUCTION_FXNS keys ("mean", "sum",
                or "prod") used to aggregate non-scalar module outputs.

        Raises:
            AssertionError: If `reduction` is not a key of REDUCTION_FXNS.
        """
        super().__init__(module, preprocessor, is_backward)
        assert reduction in REDUCTION_FXNS, "reduction function not one of {}".format(
            REDUCTION_FXNS.keys()
        )
        self.reduction_fxn = REDUCTION_FXNS[reduction]

    def expected_output_dim(self) -> int:
        """Returns the expected output dimension, which is always 1 for scalars."""
        return 1

    def forward(self, input: States | torch.Tensor) -> torch.Tensor:
        """Forward pass of the module.

        Args:
            input: The input to the module, as states or a tensor (the tensor case
                arises for conditioning inputs, e.g. conditional log Z(c)).

        Returns the output of the module, as a tensor of shape (*batch_shape, output_dim).
        """
        if isinstance(input, States):
            input = self.preprocessor(input)

        out = self.module(input)

        # Ensures estimator outputs are always scalar.
        # NOTE(review): the reduction is applied without keepdim, so the last
        # dimension is removed rather than reduced to size 1 — confirm
        # check_output_dim accepts that shape.
        if out.shape[-1] != 1:
            out = self.reduction_fxn(out, -1)

        if not self._output_dim_is_checked:
            self.check_output_dim(out)
            self._output_dim_is_checked = True

        return out


class DiscretePolicyEstimator(GFNModule):
r"""Container for forward and backward policy estimators for discrete environments.
Expand Down Expand Up @@ -290,14 +372,57 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:


class ConditionalScalarEstimator(ConditionalDiscretePolicyEstimator):
r"""Class for conditionally estimating scalars such as LogZ.

Training a GFlowNet sometimes requires the estimation of precise scalar
values, such as the partition function of flows on the DAG. In the case of a
conditional GFN, the logZ or logF estimate is also conditional. This Estimator is
designed for those cases.

The function approximator used for `final_module` need not directly output a scalar.
If it does not, `reduction` will be used to aggregate the outputs of the module into
a single scalar.

Attributes:
preprocessor: Preprocessor object that transforms raw States objects to tensors
that can be used as input to the module. Optional, defaults to
`IdentityPreprocessor`.
module: The module to use. If the module is a Tabular module (from
`gfn.utils.modules`), then the environment preprocessor needs to be an
`EnumPreprocessor`.
preprocessor: Preprocessor from the environment.
reduction_fxn: the selected torch reduction operation.
_output_dim_is_checked: Flag for tracking whether the output dimensions of
the states (after being preprocessed and transformed by the modules) have
been verified.
_is_backward: Flag for tracking whether this estimator is used for predicting
probability distributions over parents.
reduction: String denoting the reduction function used to aggregate
non-scalar module outputs (one of the REDUCTION_FXNS keys).
"""

def __init__(
self,
state_module: nn.Module,
conditioning_module: nn.Module,
final_module: nn.Module,
preprocessor: Preprocessor | None = None,
is_backward: bool = False,
reduction: str = "mean",
):
"""Initialize a conditional GFNModule with a scalar output.
Args:
state_module: The module to use for state representations. If the module is
a Tabular module (from `gfn.utils.modules`), then the environment
preprocessor needs to be an `EnumPreprocessor`.
conditioning_module: The module to use for conditioning representations.
final_module: The module to use for computing the final output.
preprocessor: Preprocessor object.
is_backward: Flags estimators of probability distributions over parents.
reduction: str name of the one of the REDUCTION_FXNS keys: {}
""".format(
list(REDUCTION_FXNS.keys())
)

super().__init__(
state_module,
conditioning_module,
Expand All @@ -306,6 +431,10 @@ def __init__(
preprocessor=preprocessor,
is_backward=is_backward,
)
assert reduction in REDUCTION_FXNS, "reduction function not one of {}".format(
REDUCTION_FXNS.keys()
)
self.reduction_fxn = REDUCTION_FXNS[reduction]

def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:
"""Forward pass of the module.
Expand All @@ -318,6 +447,10 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:
"""
out = self._forward_trunk(states, conditioning)

# Ensures estimator outputs are always scalar.
if out.shape[-1] != 1:
out = self.reduction_fxn(out, -1)

if not self._output_dim_is_checked:
self.check_output_dim(out)
self._output_dim_is_checked = True
Expand All @@ -333,13 +466,4 @@ def to_probability_distribution(
module_output: torch.Tensor,
**policy_kwargs: Any,
) -> Distribution:
"""Transform the output of the module into a probability distribution.

Args:
states: The states to use.
module_output: The output of the module as a tensor of shape (*batch_shape, output_dim).
**policy_kwargs: Keyword arguments to modify the distribution.

Returns a distribution object.
"""
raise NotImplementedError
6 changes: 5 additions & 1 deletion src/gfn/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ def sample_actions(
save_estimator_outputs: bool = False,
save_logprobs: bool = True,
**policy_kwargs: Any,
) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None,]:
) -> Tuple[
Actions,
torch.Tensor | None,
torch.Tensor | None,
]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing the last , will make this one line.

"""Samples actions from the given states.

Args:
Expand Down
Loading