Scalar estimators allow for the reduction over many output values #215

Merged: 7 commits into master from scalar_estimation_from_vectors, Nov 15, 2024

Conversation

josephdviviano (Collaborator)

The output of the nn.Module does not need to be a scalar, because the Estimator will apply a reduction to the final output if required.

I am not in love with the current code organization of modules.py -- there is some duplication, but I am thinking that a bigger refactoring effort might be en route and perhaps we should wait to optimize. Open to feedback on this!

josephdviviano self-assigned this on Nov 13, 2024
younik (Collaborator) left a comment


Looks good! I added a comment for possible refactoring.

Comment on lines 397 to 401
reduction_fxns = {
    "mean": torch.mean,
    "sum": torch.sum,
    "prod": torch.prod,
}
Collaborator


Here you can use the global constant, if you follow the previous comment.
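
For context, a one-line sketch of what this spot might look like once the dict is hoisted to a module-level constant as proposed in the thread below (the `REDUCTION_FXNS` and `reduction` names are illustrative):

```python
# Assuming a module-level REDUCTION_FXNS constant (see the next thread):
reduction_fxn = REDUCTION_FXNS[reduction]
```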

Collaborator Author


Thanks - I've done this

Comment on lines 168 to 172
reduction_fxns = {
    "mean": torch.mean,
    "sum": torch.sum,
    "prod": torch.prod,
}
Collaborator


This is a constant; by convention it should go outside, at module level, with an upper-case name.
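
A minimal sketch of the suggested refactor, hoisting the dict to module level (the constant name is an assumption, not the merged code):

```python
import torch

# Module-level constant, per the review suggestion (name assumed).
REDUCTION_FXNS = {
    "mean": torch.mean,
    "sum": torch.sum,
    "prod": torch.prod,
}
```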

hyeok9855 (Collaborator) left a comment


Looks good! I left a few questions and comments below.

Comment on lines 147 to 157
Attributes:
    preprocessor: Preprocessor object that transforms raw States objects to tensors
        that can be used as input to the module. Optional, defaults to
        `IdentityPreprocessor`.
    module: The module to use. If the module is a Tabular module (from
        `gfn.utils.modules`), then the environment preprocessor needs to be an
        `EnumPreprocessor`.
    preprocessor: Preprocessor from the environment.
    _output_dim_is_checked: Flag for tracking whether the output dimensions of
        the states (after being preprocessed and transformed by the modules) have
        been verified.
Collaborator


This needs to be updated accordingly, e.g., add `is_backward` and `reduction` and remove `_output_dim_is_checked`.
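
A hedged sketch of what the updated attribute list might look like; the `is_backward` and `reduction` descriptions below are assumptions, not the merged wording:

```python
Attributes:
    module: The module to use. If the module is a Tabular module (from
        `gfn.utils.modules`), then the environment preprocessor needs to be an
        `EnumPreprocessor`.
    preprocessor: Preprocessor from the environment.
    is_backward: Whether the estimator models the backward policy (assumed).
    reduction: The reduction applied to the module's output (assumed).
```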

Comment on lines 366 to 378
Attributes:
    preprocessor: Preprocessor object that transforms raw States objects to tensors
        that can be used as input to the module. Optional, defaults to
        `IdentityPreprocessor`.
    module: The module to use. If the module is a Tabular module (from
        `gfn.utils.modules`), then the environment preprocessor needs to be an
        `EnumPreprocessor`.
    preprocessor: Preprocessor from the environment.
    reduction_fxn: the selected torch reduction operation.
    _output_dim_is_checked: Flag for tracking whether the output dimensions of
        the states (after being preprocessed and transformed by the modules) have
        been verified.
"""
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DITTO

Collaborator Author


thanks :)

@@ -134,9 +134,71 @@ def to_probability_distribution(


class ScalarEstimator(GFNModule):
    r"""Class for estimating scalars such as LogZ.
Collaborator


Note that logZ for unconditional TB is usually modeled with a single learnable parameter (nn.Parameter).
Should we consider modifying ScalarEstimator to support this kind of behavior?

Collaborator Author


The GFNs themselves support this directly (you do not need to pass an estimator at all, instead you just pass a float for Z).
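
For reference, a minimal sketch of that direct route, modeling logZ as a single learnable parameter rather than an estimator (the separate learning rate is illustrative, not the library's default):

```python
import torch
import torch.nn as nn

# A single learnable scalar for logZ, as is typical for unconditional TB.
logZ = nn.Parameter(torch.zeros(1))

# logZ is often given its own (larger) learning rate than the policy networks.
optimizer = torch.optim.Adam([{"params": [logZ], "lr": 1e-1}])
```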

Collaborator


This comment is because of "such as LogZ" in the docstring!

Collaborator Author


I'm not entirely sure what would be most clear here but I'm open to suggestions.

Collaborator


Why not just say "state flow functions of DB/SubTB" instead?

Comment on lines 38 to 42
) -> Tuple[
    Actions,
    torch.Tensor | None,
    torch.Tensor | None,
]:
Collaborator


Removing the trailing `,` will make this one line.
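
That is, presumably:

```python
) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None]:
```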

def expected_output_dim(self) -> int:
    return 1

def forward(self, input: States | torch.Tensor) -> torch.Tensor:
Collaborator


In which case is the input a `torch.Tensor`?

Collaborator Author


Yes, I was looking at this and I'm not entirely sure. It might be the conditioning case: we currently don't have any sort of container for conditioning, so it is done with a raw Tensor.

Collaborator


Oh, it should be conditioning (e.g., conditional log Z(c)).
However, it might be a bit confusing whether to use ConditionalScalarEstimator or ScalarEstimator to model log Z(c).

Collaborator Author


Well, of note, ScalarEstimators are used for more than just logZ, but in this case I see it like this:

  • LogZ can be a single parameter.
  • LogZ can be estimated using a neural network - in this case, the output of the network can actually be multiple items that are averaged over.
  • LogZ can be conditionally estimated using a neural network - in this case, the output of the network can actually be multiple items that are averaged over.

From an optimization POV, sometimes having logZ only be estimated by a single parameter can cause problems (i.e., the gradients push the number around a lot), so using a neural network helps.

I agree we could make it clearer though -- I am open to suggestions.
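
To make the middle two cases concrete, a minimal sketch of a network whose multi-valued output is reduced to a scalar logZ (shapes and names are illustrative, not the library's API):

```python
import torch
import torch.nn as nn

REDUCTION_FXNS = {"mean": torch.mean, "sum": torch.sum, "prod": torch.prod}

# A network whose raw output is not a scalar: 16 values per input.
net = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 16))

x = torch.randn(32, 8)    # batch of 32 preprocessed states
out = net(x)              # shape (32, 16): many output values per state
logZ = REDUCTION_FXNS["mean"](out, dim=-1, keepdim=True)  # shape (32, 1)
```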

Collaborator Author


ConditionalScalarEstimator is used to take in both the State and the Conditioning, i.e., it's a two-headed estimator. I think this is the normal conditioning case.
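
For illustration, a hedged sketch of such a two-headed module, taking both state features and a conditioning tensor (the class and layer names are made up for this example):

```python
import torch
import torch.nn as nn

class TwoHeadedScalarModule(nn.Module):
    """Illustrative two-headed module: one trunk per input, merged to a scalar."""

    def __init__(self, state_dim: int, cond_dim: int, hidden: int = 64):
        super().__init__()
        self.state_trunk = nn.Sequential(nn.Linear(state_dim, hidden), nn.ReLU())
        self.cond_trunk = nn.Sequential(nn.Linear(cond_dim, hidden), nn.ReLU())
        self.head = nn.Linear(2 * hidden, 1)

    def forward(self, states: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        # Encode each input separately, then merge for a single scalar output,
        # e.g., a conditional log Z(c) per batch element.
        merged = torch.cat([self.state_trunk(states), self.cond_trunk(conditioning)], dim=-1)
        return self.head(merged)
```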

hyeok9855 (Collaborator) left a comment


Approved!

josephdviviano merged commit 8259a21 into master on Nov 15, 2024
4 checks passed
josephdviviano deleted the scalar_estimation_from_vectors branch on Nov 15, 2024 at 00:52
josephdviviano restored the scalar_estimation_from_vectors branch on Nov 15, 2024 at 00:53
josephdviviano deleted the scalar_estimation_from_vectors branch on Nov 15, 2024 at 00:53