
conditional gfn #188

Merged Oct 24, 2024 (33 commits)

Commits
6e8dc4d
example of conditional GFN computation with TB only (for now)
josephdviviano Sep 25, 2024
39fb5ee
should be no change
josephdviviano Sep 25, 2024
2bc2263
Trajectories objects now have an optional .conditional field which opt…
josephdviviano Sep 25, 2024
99afaf3
small changes to logZ parameter handling, optionally incorporate cond…
josephdviviano Sep 25, 2024
e6d25a0
logZ is optionally computed using a conditioning vector
josephdviviano Sep 25, 2024
2c72bf9
NeuralNets now have input/output dims
josephdviviano Sep 25, 2024
580c455
added a ConditionalDiscretePolicyEstimator, and the forward of GFNMod…
josephdviviano Sep 25, 2024
a74872f
added conditioning to sampler, which will save the tensor as an attri…
josephdviviano Sep 25, 2024
056d935
black
josephdviviano Sep 25, 2024
96b725c
API changes adapted
josephdviviano Oct 1, 2024
5cd32a7
added conditioning to all gflownets
josephdviviano Oct 1, 2024
877c4a0
both trajectories and transitions can now store a conditioning tensor
josephdviviano Oct 1, 2024
279a313
input_dim setting is now private
josephdviviano Oct 1, 2024
65135c1
added exception handling for all estimator calls potentially involvin…
josephdviviano Oct 1, 2024
b4c418c
API change -- n vs. n_trajectories
josephdviviano Oct 1, 2024
738b062
change test_box target value
josephdviviano Oct 1, 2024
4434e5f
API changes
josephdviviano Oct 1, 2024
851e03e
hacky fix for problematic test (added TODO)
josephdviviano Oct 1, 2024
5152295
working examples for all 4 major losses
josephdviviano Oct 4, 2024
1d64b55
added conditioning indexing for correct broadcasting
josephdviviano Oct 4, 2024
348ee82
added a ConditionalScalarEstimator which subclasses ConditionalDiscre…
josephdviviano Oct 4, 2024
9120afe
added modified DB example
josephdviviano Oct 4, 2024
f59f4de
conditioning added to modified db example
josephdviviano Oct 4, 2024
c5ef7ea
black
josephdviviano Oct 4, 2024
d67dfd5
reorganized keyword arguments and fixed some type errors (not all)
josephdviviano Oct 9, 2024
d56a798
reorganized keyword arguments and fixed some type errors (not all)
josephdviviano Oct 9, 2024
db8844c
added typing and a ConditionalScalarEstimator
josephdviviano Oct 9, 2024
e03c03a
added typing
josephdviviano Oct 9, 2024
6b47e06
typing
josephdviviano Oct 9, 2024
988faf0
typing
josephdviviano Oct 9, 2024
f2bbce3
added kwargs
josephdviviano Oct 9, 2024
eb13a2d
renamed torso to trunk
josephdviviano Oct 24, 2024
fd3d9dc
renamed torso to trunk
josephdviviano Oct 24, 2024
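The commits above thread conditioning support through the whole stack: estimators, samplers, gflownets, and the containers that store trajectories and transitions. The core idea — a policy network that encodes the state and the conditioning vector separately, then combines both encodings in a final trunk (the last two commits rename "torso" to "trunk") — can be sketched as follows. All class and attribute names here are illustrative assumptions, not the library's actual API:

```python
import torch
import torch.nn as nn


class ConditionalPolicyNet(nn.Module):
    """Hypothetical sketch of a conditional policy network.

    The state and the conditioning tensor are encoded by separate trunks;
    their concatenated features feed a final trunk that emits action logits.
    """

    def __init__(self, state_dim: int, cond_dim: int, hidden_dim: int, n_actions: int):
        super().__init__()
        self.state_trunk = nn.Sequential(nn.Linear(state_dim, hidden_dim), nn.ReLU())
        self.cond_trunk = nn.Sequential(nn.Linear(cond_dim, hidden_dim), nn.ReLU())
        self.final_trunk = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions),
        )

    def forward(self, states: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        # Concatenate the two encodings along the feature dimension.
        h = torch.cat([self.state_trunk(states), self.cond_trunk(conditioning)], dim=-1)
        return self.final_trunk(h)


net = ConditionalPolicyNet(state_dim=4, cond_dim=1, hidden_dim=16, n_actions=3)
logits = net(torch.randn(8, 4), torch.randn(8, 1))  # batch of 8 conditioned states
print(tuple(logits.shape))  # (8, 3)
```

This mirrors the pattern the PR applies to `ConditionalDiscretePolicyEstimator` and `ConditionalScalarEstimator` (e.g. conditioning logZ on a vector), though the real implementations live in `gfn.modules`.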
26 changes: 18 additions & 8 deletions src/gfn/gflownet/flow_matching.py
@@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import Tuple, Any, Union

import torch
from torchtyping import TensorType as TT
@@ -8,7 +8,7 @@
from gfn.gflownet.base import GFlowNet
from gfn.modules import DiscretePolicyEstimator, ConditionalDiscretePolicyEstimator
from gfn.samplers import Sampler
from gfn.states import DiscreteStates
from gfn.states import DiscreteStates, States
from gfn.utils.handlers import (
no_conditioning_exception_handler,
has_conditioning_exception_handler,
@@ -45,9 +45,10 @@ def sample_trajectories(
self,
env: Env,
n: int,
conditioning: torch.Tensor | None = None,
save_logprobs: bool = True,
save_estimator_outputs: bool = False,
**policy_kwargs: Optional[dict],
**policy_kwargs: Any,
) -> Trajectories:
"""Sample trajectories with optional kwargs controlling the policy."""
if not env.is_discrete:
@@ -58,8 +59,9 @@ def sample_trajectories(
trajectories = sampler.sample_trajectories(
env,
n=n,
save_estimator_outputs=save_estimator_outputs,
conditioning=conditioning,
save_logprobs=save_logprobs,
save_estimator_outputs=save_estimator_outputs,
**policy_kwargs,
)
return trajectories
@@ -176,7 +178,12 @@ def reward_matching_loss(
return (terminating_log_edge_flows - log_rewards).pow(2).mean()

def loss(
self, env: Env, states_tuple: Tuple[DiscreteStates, DiscreteStates]
self,
env: Env,
states_tuple: Union[
Tuple[DiscreteStates, DiscreteStates, torch.Tensor, torch.Tensor],
Tuple[DiscreteStates, DiscreteStates, None, None],
],
) -> TT[0, float]:
"""Given a batch of non-terminal and terminal states, compute a loss.

@@ -198,8 +205,11 @@ def loss(
)
return fm_loss + self.alpha * rm_loss

def to_training_samples(
self, trajectories: Trajectories
) -> tuple[DiscreteStates, DiscreteStates, torch.Tensor]:
def to_training_samples(self, trajectories: Trajectories) -> Union[
Tuple[DiscreteStates, DiscreteStates, torch.Tensor, torch.Tensor],
Tuple[DiscreteStates, DiscreteStates, None, None],
Tuple[States, States, torch.Tensor, torch.Tensor],
Tuple[States, States, None, None],
]:
"""Converts a batch of trajectories into a batch of training samples."""
return trajectories.to_non_initial_intermediary_and_terminating_states()
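As the new signature shows, `to_training_samples` now returns a 4-tuple whose last two slots carry the conditioning tensors for the intermediary and terminating states, or `None` in the unconditional case. A minimal, framework-free sketch of how a caller might branch on that tuple — the helper name and placeholder values are hypothetical, not part of the library:

```python
def split_training_samples(states_tuple):
    """Unpack the 4-tuple and report whether conditioning is present.

    Mirrors the Union return type in the diff: the last two entries are
    either both tensors or both None.
    """
    intermediary, terminating, cond_i, cond_t = states_tuple
    conditioned = cond_i is not None and cond_t is not None
    return intermediary, terminating, conditioned


# Unconditional case: conditioning slots are None.
_, _, has_cond = split_training_samples(("s_mid", "s_term", None, None))
assert has_cond is False

# Conditional case: tensors (here stand-in lists) fill the last two slots.
_, _, has_cond = split_training_samples(("s_mid", "s_term", [0.5], [0.5]))
assert has_cond is True
```

The FM loss in this file receives the same shape of tuple, which is why its `states_tuple` annotation became a `Union` of the conditioned and unconditioned variants.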