Scalar estimators allow for the reduction over many output values (i.e., the output of the nn.Module does not need to be a scalar, because the Estimator will apply a reduction to the final output if required).
josephdviviano committed Nov 13, 2024
1 parent 4a9f112 commit 1c4ec37
Showing 1 changed file: src/gfn/modules.py (99 additions, 9 deletions).
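In practice, the new `reduction` argument lets a multi-output head back a scalar estimator. A rough usage sketch (the layer sizes are illustrative, and a preprocessor may be required depending on the environment):

import torch.nn as nn
from gfn.modules import ScalarEstimator

# Hypothetical logZ head emitting 4 values per input; the estimator
# aggregates them with torch.sum inside forward().
logZ_module = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 4))
logZ_estimator = ScalarEstimator(module=logZ_module, reduction="sum")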
@@ -134,9 +134,71 @@ def to_probability_distribution(


class ScalarEstimator(GFNModule):
r"""Class for estimating scalars such as LogZ.
Training a GFlowNet requires sometimes requires the estimation of precise scalar
values, such as the partition function of flows on the DAG. This Estimator is
designed for those cases.
The function approximator used for `module` need not directly output a scalar. If
it does not, `reduction` will be used to aggregate the outputs of the module into
a single scalar.
Attributes:
preprocessor: Preprocessor object that transforms raw States objects to tensors
that can be used as input to the module. Optional, defaults to
`IdentityPreprocessor`.
module: The module to use. If the module is a Tabular module (from
`gfn.utils.modules`), then the environment preprocessor needs to be an
`EnumPreprocessor`.
preprocessor: Preprocessor from the environment.
_output_dim_is_checked: Flag for tracking whether the output dimenions of
the states (after being preprocessed and transformed by the modules) have
been verified.
"""

    def __init__(
        self,
        module: nn.Module,
        preprocessor: Preprocessor | None = None,
        is_backward: bool = False,
        reduction: str = "mean",
    ):
        super().__init__(module, preprocessor, is_backward)
        reduction_fxns = {
            "mean": torch.mean,
            "sum": torch.sum,
            "prod": torch.prod,
        }
        assert reduction in reduction_fxns, f"reduction must be one of {list(reduction_fxns)}"
        self.reduction_fxn = reduction_fxns[reduction]

    def expected_output_dim(self) -> int:
        return 1

    def forward(self, input: States | torch.Tensor) -> torch.Tensor:
        """Forward pass of the module.

        Args:
            input: The input to the module, as states or a tensor.

        Returns the output of the module, as a tensor of shape (*batch_shape, 1).
        """
        if isinstance(input, States):
            input = self.preprocessor(input)

        out = self.module(input)

        # Ensures estimator outputs are always scalar. keepdim=True preserves
        # the trailing dimension, so the output has shape (*batch_shape, 1).
        if out.shape[-1] != 1:
            out = self.reduction_fxn(out, -1, keepdim=True)

        if not self._output_dim_is_checked:
            self.check_output_dim(out)
            self._output_dim_is_checked = True

        return out
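A torch-only sketch of the reduction step above, with illustrative shapes: a head emitting 8 values per state is reduced to one scalar per state, and keepdim=True preserves the trailing dimension that check_output_dim expects.

import torch

batch_shape, output_dim = (16,), 8
out = torch.rand(*batch_shape, output_dim)       # module output, shape (16, 8)
reduced = torch.mean(out, dim=-1, keepdim=True)  # shape (16, 1)
assert reduced.shape == (*batch_shape, 1)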


class DiscretePolicyEstimator(GFNModule):
r"""Container for forward and backward policy estimators for discrete environments.
@@ -290,13 +352,39 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:


class ConditionalScalarEstimator(ConditionalDiscretePolicyEstimator):
r"""Class for conditionally estimating scalars such as LogZ.
Training a GFlowNet requires sometimes requires the estimation of precise scalar
values, such as the partition function of flows on the DAG. In the case of a
conditional GFN, the logZ or logF estimate is also conditional. This Estimator is
designed for those cases.
The function approximator used for `module` need not directly output a scalar. If
it does not, `reduction` will be used to aggregate the outputs of the module into
a single scalar.
Attributes:
preprocessor: Preprocessor object that transforms raw States objects to tensors
that can be used as input to the module. Optional, defaults to
`IdentityPreprocessor`.
module: The module to use. If the module is a Tabular module (from
`gfn.utils.modules`), then the environment preprocessor needs to be an
`EnumPreprocessor`.
preprocessor: Preprocessor from the environment.
reduction_fxn: the selected torch reduction operation.
_output_dim_is_checked: Flag for tracking whether the output dimensions of
the states (after being preprocessed and transformed by the modules) have
been verified.
"""

    def __init__(
        self,
        state_module: nn.Module,
        conditioning_module: nn.Module,
        final_module: nn.Module,
        preprocessor: Preprocessor | None = None,
        is_backward: bool = False,
        reduction: str = "mean",
    ):
        super().__init__(
            state_module,
@@ -306,6 +394,13 @@ def __init__(
            preprocessor=preprocessor,
            is_backward=is_backward,
        )
        reduction_fxns = {
            "mean": torch.mean,
            "sum": torch.sum,
            "prod": torch.prod,
        }
        assert reduction in reduction_fxns, f"reduction must be one of {list(reduction_fxns)}"
        self.reduction_fxn = reduction_fxns[reduction]

    def forward(self, states: States, conditioning: torch.Tensor) -> torch.Tensor:
        """Forward pass of the module.
@@ -318,6 +413,10 @@ def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor:
"""
out = self._forward_trunk(states, conditioning)

# Ensures estimator outputs are always scalar.
if out.shape[-1] != 1:
out = self.reduction_fxn(out, -1)

if not self._output_dim_is_checked:
self.check_output_dim(out)
self._output_dim_is_checked = True
@@ -333,13 +432,4 @@ def to_probability_distribution(
        module_output: torch.Tensor,
        **policy_kwargs: Any,
    ) -> Distribution:
        """Transform the output of the module into a probability distribution.

        Args:
            states: The states to use.
            module_output: The output of the module as a tensor of shape
                (*batch_shape, output_dim).
            **policy_kwargs: Keyword arguments to modify the distribution.

        Not implemented for scalar estimators, which do not define a
        probability distribution.
        """
        raise NotImplementedError
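A torch-only sketch of the conditional trunk-and-reduce pattern above; that the trunk concatenates the two embeddings is an assumption for illustration, and all shapes are arbitrary.

import torch
import torch.nn as nn

state_module = nn.Linear(10, 16)        # embeds preprocessed states
conditioning_module = nn.Linear(3, 16)  # embeds the conditioning tensor
final_module = nn.Linear(32, 4)         # head emitting 4 values per state

states_tensor = torch.rand(8, 10)
conditioning = torch.rand(8, 3)
trunk = torch.cat([state_module(states_tensor), conditioning_module(conditioning)], dim=-1)
out = final_module(trunk)                    # shape (8, 4)
out = torch.mean(out, dim=-1, keepdim=True)  # reduced to shape (8, 1)
assert out.shape == (8, 1)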
