
conditional gfn #188

Merged: 33 commits into master from conditional_gfn, Oct 24, 2024
Conversation

@josephdviviano (Collaborator) commented Sep 25, 2024

Supports conditioning on a tensor of shape=[n_trajectories, n_cond_dims]. This is passed by the user during a call to the sampler.

Implemented for all GFlowNets. Note that the current version expects a particular kind of estimator. I can imagine this will lead to future changes - e.g., we should have some Estimators which expect huggingface models, so we can use them to produce conditioning vectors / to initialize the policy (this will obviously be a future PR).

Note that the conditioning is useless in my example; we should have a better use-case envisioned for the demo. The demo is currently not complete for all GFlowNet types.
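A minimal sketch of the interface described above: the user builds a conditioning tensor with one row per trajectory and hands it to the sampler. The `sampler.sample_trajectories` call is commented out because `sampler` and `env` are stand-ins here, not real objects from the library.

```python
import torch

# One conditioning row per trajectory, as described in the PR summary.
n_trajectories, n_cond_dims = 16, 1
conditioning = torch.rand((n_trajectories, n_cond_dims))

# Illustrative call (names assumed, not the library's exact API):
# trajectories = sampler.sample_trajectories(env, n=n_trajectories, conditioning=conditioning)
```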

@josephdviviano josephdviviano added the enhancement New feature or request label Sep 25, 2024
@josephdviviano josephdviviano self-assigned this Sep 25, 2024
@josephdviviano (Collaborator, Author):

Don't worry about the tests - they should be easy to fix.

I can make the changes for DB, Sub-TB, and FM pretty easily, before a proper review, if we agree this is a good approach.


or

$s \mapsto (P_B(s' \mid s, c))_{s' \in Parents(s)}$.
Collaborator:

Might be worth mentioning that this is a very specific conditioning use-case, where the condition is encoded separately and the embeddings are concatenated.

I don't think we can do a generic one, but this should be enough as an example!
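The "encode separately, concatenate embeddings" pattern under discussion can be sketched roughly as below. All names (`ConditionalSketch`, the module attributes, the dimensions) are illustrative, not the library's actual classes.

```python
import torch
import torch.nn as nn

class ConditionalSketch(nn.Module):
    """Illustrative conditional estimator: separate encoders, concatenated embeddings."""

    def __init__(self, state_dim: int, cond_dim: int, hidden_dim: int, out_dim: int):
        super().__init__()
        self.state_module = nn.Sequential(nn.Linear(state_dim, hidden_dim), nn.ReLU())
        self.conditioning_module = nn.Sequential(nn.Linear(cond_dim, hidden_dim), nn.ReLU())
        self.final_module = nn.Linear(2 * hidden_dim, out_dim)

    def forward(self, states: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        # Encode state and condition separately, then concatenate the embeddings.
        h = torch.cat(
            [self.state_module(states), self.conditioning_module(conditioning)], dim=-1
        )
        return self.final_module(h)

module = ConditionalSketch(state_dim=4, cond_dim=1, hidden_dim=8, out_dim=3)
logits = module(torch.zeros(5, 4), torch.ones(5, 1))
```

Other schemes (cross-attention, FiLM, concatenating the condition to the raw state) would slot into the same two-argument `forward` signature.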

Collaborator Author:

What other conditioning approaches would be worth including? Cross attention?

Collaborator Author:

In general I would think the conditioning should be embedded / encoded separately --- or would the conditioning just need to be concatenated to the state before input? I could add support for that.

Collaborator:

I don't think there is an exhaustive list of ways we can process the condition. What you have is great as an example. I suggest you just add a comment or doc noting that the user might want to write their own module.

@@ -68,7 +67,28 @@ def sample_actions(
the sampled actions under the probability distribution of the given
states.
"""
estimator_output = self.estimator(states)
# TODO: Should estimators instead ignore None for the conditioning vector?
Collaborator:

Wouldn't it be cleaner with fewer if/else blocks?

Collaborator Author:

Yes, there's a bit of cruft with all the if-else blocks, but as it stands an estimator can accept either one or two arguments, and I think it's good if it fails noisily... what do you think?

Collaborator:

OK! Makes sense.

Collaborator Author:

I added these exception_handlers to reduce the cruft.
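A rough sketch of the "fail noisily" idea mentioned above: a small context manager that turns a `TypeError` from calling an estimator with the wrong number of arguments into a clearer message. All names here are illustrative, not the PR's actual handlers.

```python
from contextlib import contextmanager

@contextmanager
def conditioning_exception_handler(estimator_name: str):
    # Re-raise arity errors with a message explaining the conditioning contract.
    try:
        yield
    except TypeError as e:
        raise TypeError(
            f"{estimator_name} received an unexpected conditioning signature; "
            "conditional estimators take (states, conditioning), "
            "unconditional ones take (states,)."
        ) from e

def unconditional_estimator(states):
    return states

# Calling with a spurious conditioning argument now fails with a clear message.
try:
    with conditioning_exception_handler("unconditional_estimator"):
        unconditional_estimator([1, 2], [0.0])  # wrong arity -> TypeError
except TypeError as e:
    message = str(e)
```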

@saleml (Collaborator) commented Sep 25, 2024

LGTM! Looking forward to testing this feature.

@josephdviviano josephdviviano marked this pull request as ready for review October 1, 2024 16:34
@saleml (Collaborator) left a review:

I'm happy to see this being added to the library. Great work! Great code design, and thanks for factorizing a few other things, including the context managers / error handlers.

I left a few comments and a suggestion for the script.

@@ -32,41 +35,41 @@ class GFlowNet(ABC, nn.Module, Generic[TrainingSampleType]):
def sample_trajectories(
self,
env: Env,
n_samples: int,
Collaborator:

It looks like you're handling the conditioning input to this function as a kwarg, whereas the sampler's sample_trajectories has an explicit conditioning input. I'm wondering if you have a particular reason for this choice.

@josephdviviano (Collaborator, Author), Oct 8, 2024:

I think maybe all functions should use an explicit conditioning kwarg; what do you think? I can make those changes.

Collaborator:

I agree that it would be cleaner

@josephdviviano (Collaborator, Author), Oct 9, 2024:

It should be done now; let me know if I missed something.

conditioning = torch.rand((batch_size, 1))
conditioning = (conditioning > 0.5).to(torch.float)  # Randomly 0 or 1.

trajectories = gflownet.sample_trajectories(
Collaborator:

[Screenshot of Pylance warnings, 2024-10-05]

Pylance is not happy with these two variables. I'm wondering if this is due to using **kwargs (see my comment in the base.py file). If so, it would be nice to decide whether we should stop caring about Pylance and co. hereafter, or keep caring, in which case: no more kwargs in the codebase.

Collaborator Author:

I think I fixed this by moving to **kwargs: Any, but we have a multitude of other, harder-to-handle Pylance issues that I'm not sure what to do about; that warrants a discussion bigger than the scope of this PR, I think.

print("+ Training Conditional {}!".format(type(gflownet)))
for i in (pbar := tqdm(range(n_iterations))):
conditioning = torch.rand((batch_size, 1))
conditioning = (conditioning > 0.5).to(torch.float)  # Randomly 0 or 1.
Collaborator:

IIUC, the conditioning doesn't change anything in this example. While this file is a great way to show how one can code their conditional GFlowNet, what do you think of slightly altering the setting here, e.g. by making the environment conditional (say, hiding one of the 4 modes when conditioning=1), and then, post-training, adding some validation where we compare the resulting pair of distributions to the pair of target distributions?

Collaborator Author:

Yes, you're right. Can we save this for a follow-up PR?

Collaborator Author:

I filed the issue as #190. If you agree, I'd like to do this work separately.
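The conditional-environment setting suggested in this thread could look roughly like the following (illustrative only; the 4-mode reward, the choice of which mode is hidden, and all names are assumptions, not code from the PR or the issue):

```python
import torch

def conditional_reward(x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
    """Reward with 4 modes (x in {0,1,2,3}); mode 3 is hidden when conditioning == 1."""
    base = torch.ones_like(x, dtype=torch.float)  # every mode has reward 1
    hidden = (conditioning.squeeze(-1) == 1) & (x == 3)
    return torch.where(hidden, torch.zeros_like(base), base)

x = torch.tensor([0, 3, 3])
cond = torch.tensor([[1.0], [0.0], [1.0]])
r = conditional_reward(x, cond)  # mode 3 is zeroed only where conditioning == 1
```

Post-training validation would then compare the sampled distribution under each condition against the corresponding target distribution.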

Comment on lines 219 to 232
gflownet = build_tb_gflownet(environment)
train(environment, gflownet)

gflownet = build_db_gflownet(environment)
train(environment, gflownet)

gflownet = build_db_mod_gflownet(environment)
train(environment, gflownet)

gflownet = build_subTB_gflownet(environment)
train(environment, gflownet)

gflownet = build_fm_gflownet(environment)
train(environment, gflownet)
Collaborator:

argparse this?

Collaborator Author:

Done.
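The argparse change suggested above might look something like this. The `build_*` names mirror the script's functions but are stubbed here for illustration.

```python
import argparse

# Stand-ins for the script's builder functions.
def build_tb_gflownet(env): ...
def build_db_gflownet(env): ...
def build_fm_gflownet(env): ...

BUILDERS = {"tb": build_tb_gflownet, "db": build_db_gflownet, "fm": build_fm_gflownet}

parser = argparse.ArgumentParser()
parser.add_argument("--gflownet", choices=sorted(BUILDERS), default="tb")
args = parser.parse_args(["--gflownet", "fm"])  # e.g. python script.py --gflownet fm

# Dispatch to the chosen builder instead of training every variant in sequence.
builder = BUILDERS[args.gflownet]
```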

Collaborator:

I'm pleasantly surprised no change is needed for the LogPartitionVarianceLoss. Right?

Collaborator Author:

I don't think we need the conditioning information here, and I agree it's nice that the code naturally reflected that. Please correct me if I misunderstand this loss.
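For context, the log-partition variance loss as commonly defined (notation assumed here, not taken from the PR) computes, for each complete trajectory $\tau$ ending in $x_\tau$,

$$\zeta(\tau) = \log \frac{R(x_\tau)\, \prod_{t} P_B(s_t \mid s_{t+1})}{\prod_{t} P_F(s_{t+1} \mid s_t)},$$

and minimizes $\mathrm{Var}_{\tau}\!\left[\zeta(\tau)\right]$ over the sampled batch, since every $\zeta(\tau)$ is an estimate of $\log Z$. Because $\zeta$ is built entirely from log-probabilities that the estimators already produce, any conditioning reaches the loss implicitly through those estimators, which is one way to see why no explicit conditioning argument is needed here.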

) -> tuple[DiscreteStates, DiscreteStates, torch.Tensor]:
def to_training_samples(self, trajectories: Trajectories) -> Union[
Tuple[DiscreteStates, DiscreteStates, torch.Tensor, torch.Tensor],
Tuple[DiscreteStates, DiscreteStates, None, None],
Collaborator:

🤯

@@ -240,13 +240,20 @@ def __init__(
self.conditioning_module = conditioning_module
self.final_module = final_module

def forward(
self, states: States, conditioning: torch.tensor
def _forward_trunk(
Collaborator:

Is what you call "trunk" the same thing I called "torso" before?

Collaborator Author:

Yeah -- let me unify the naming

@saleml (Collaborator) commented Oct 17, 2024

LGTM!

Thanks for the PR

@josephdviviano josephdviviano merged commit d2d959e into master Oct 24, 2024
3 checks passed
@josephdviviano josephdviviano deleted the conditional_gfn branch October 24, 2024 19:22