conditional gfn #188

Merged: 33 commits, Oct 24, 2024

Commits
6e8dc4d
example of conditional GFN computation with TB only (for now)
josephdviviano Sep 25, 2024
39fb5ee
should be no change
josephdviviano Sep 25, 2024
2bc2263
Trajectories objects now have an optional .conditional field which opt…
josephdviviano Sep 25, 2024
99afaf3
small changes to logz parameter handling, optionally incorporate cond…
josephdviviano Sep 25, 2024
e6d25a0
logZ is optionally computed using a conditioning vector
josephdviviano Sep 25, 2024
2c72bf9
NeuralNets now have input/output dims
josephdviviano Sep 25, 2024
580c455
added a ConditionalDiscretePolicyEstimator, and the forward of GFNMod…
josephdviviano Sep 25, 2024
a74872f
added conditioning to sampler, which will save the tensor as an attri…
josephdviviano Sep 25, 2024
056d935
black
josephdviviano Sep 25, 2024
96b725c
API changes adapted
josephdviviano Oct 1, 2024
5cd32a7
added conditioning to all gflownets
josephdviviano Oct 1, 2024
877c4a0
both trajectories and transitions can now store a conditioning tensor
josephdviviano Oct 1, 2024
279a313
input_dim setting is now private
josephdviviano Oct 1, 2024
65135c1
added exception handling for all estimator calls potentially involvin…
josephdviviano Oct 1, 2024
b4c418c
API change -- n vs. n_trajectories
josephdviviano Oct 1, 2024
738b062
change test_box target value
josephdviviano Oct 1, 2024
4434e5f
API changes
josephdviviano Oct 1, 2024
851e03e
hacky fix for problematic test (added TODO)
josephdviviano Oct 1, 2024
5152295
working examples for all 4 major losses
josephdviviano Oct 4, 2024
1d64b55
added conditioning indexing for correct broadcasting
josephdviviano Oct 4, 2024
348ee82
added a ConditionalScalarEstimator which subclasses ConditionalDiscre…
josephdviviano Oct 4, 2024
9120afe
added modified DB example
josephdviviano Oct 4, 2024
f59f4de
conditioning added to modified db example
josephdviviano Oct 4, 2024
c5ef7ea
black
josephdviviano Oct 4, 2024
d67dfd5
reorganized keyword arguments and fixed some type errors (not all)
josephdviviano Oct 9, 2024
d56a798
reorganized keyword arguments and fixed some type errors (not all)
josephdviviano Oct 9, 2024
db8844c
added typing and a ConditionalScalarEstimator
josephdviviano Oct 9, 2024
e03c03a
added typing
josephdviviano Oct 9, 2024
6b47e06
typing
josephdviviano Oct 9, 2024
988faf0
typing
josephdviviano Oct 9, 2024
f2bbce3
added kwargs
josephdviviano Oct 9, 2024
eb13a2d
renamed torso to trunk
josephdviviano Oct 24, 2024
fd3d9dc
renamed torso to trunk
josephdviviano Oct 24, 2024
4 changes: 2 additions & 2 deletions README.md
@@ -82,7 +82,7 @@ module_PF = NeuralNet(
module_PB = NeuralNet(
input_dim=env.preprocessor.output_dim,
output_dim=env.n_actions - 1,
torso=module_PF.torso # We share all the parameters of P_F and P_B, except for the last layer
trunk=module_PF.trunk # We share all the parameters of P_F and P_B, except for the last layer
)

# 3 - We define the estimators.
@@ -136,7 +136,7 @@ module_PF = NeuralNet(
module_PB = NeuralNet(
input_dim=env.preprocessor.output_dim,
output_dim=env.n_actions - 1,
torso=module_PF.torso # We share all the parameters of P_F and P_B, except for the last layer
trunk=module_PF.trunk # We share all the parameters of P_F and P_B, except for the last layer
)
module_logF = NeuralNet(
input_dim=env.preprocessor.output_dim,
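To make the torso → trunk rename concrete: passing trunk=module_PF.trunk hands P_B the very same nn.Module instance, so every layer except the two output heads is shared. Below is a minimal, self-contained sketch of that pattern, using a hypothetical TinyNet in place of gfn's NeuralNet (whose exact constructor is not shown in this diff):

```python
from __future__ import annotations

import torch
import torch.nn as nn


# TinyNet is a hypothetical stand-in for gfn's NeuralNet, sketching only the sharing
# pattern the README describes: reuse the caller's trunk when one is given, so both
# networks update the same hidden-layer weights while keeping separate output heads.
class TinyNet(nn.Module):
    def __init__(self, input_dim: int, output_dim: int, trunk: nn.Module | None = None):
        super().__init__()
        self.trunk = trunk if trunk is not None else nn.Sequential(
            nn.Linear(input_dim, 64), nn.ReLU()
        )
        self.head = nn.Linear(64, output_dim)  # per-network last layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.trunk(x))


pf_module = TinyNet(input_dim=8, output_dim=5)
pb_module = TinyNet(input_dim=8, output_dim=4, trunk=pf_module.trunk)
assert pb_module.trunk is pf_module.trunk  # shared trunk, distinct heads
```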
41 changes: 37 additions & 4 deletions src/gfn/containers/trajectories.py
@@ -1,11 +1,12 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Sequence, Union, Tuple


if TYPE_CHECKING:
from gfn.actions import Actions
from gfn.env import Env
from gfn.states import States
from gfn.states import States, DiscreteStates

import numpy as np
import torch
@@ -50,6 +51,7 @@ def __init__(
self,
env: Env,
states: States | None = None,
conditioning: torch.Tensor | None = None,
actions: Actions | None = None,
when_is_done: TT["n_trajectories", torch.long] | None = None,
is_backward: bool = False,
@@ -76,6 +78,7 @@ def __init__(
is used to compute the rewards, at each call of self.log_rewards
"""
self.env = env
self.conditioning = conditioning
self.is_backward = is_backward
self.states = (
states if states is not None else env.states_from_batch_shape((0, 0))
@@ -315,6 +318,15 @@ def extend(self, other: Trajectories) -> None:

def to_transitions(self) -> Transitions:
"""Returns a `Transitions` object from the trajectories."""
if self.conditioning is not None:
traj_len = self.actions.batch_shape[0]
expand_dims = (traj_len,) + tuple(self.conditioning.shape)
conditioning = self.conditioning.unsqueeze(0).expand(expand_dims)[
~self.actions.is_dummy
]
else:
conditioning = None

states = self.states[:-1][~self.actions.is_dummy]
next_states = self.states[1:][~self.actions.is_dummy]
actions = self.actions[~self.actions.is_dummy]
@@ -348,6 +360,7 @@ def to_transitions(self) -> Transitions:
return Transitions(
env=self.env,
states=states,
conditioning=conditioning,
actions=actions,
is_done=is_done,
next_states=next_states,
@@ -363,7 +376,10 @@ def to_states(self) -> States:

def to_non_initial_intermediary_and_terminating_states(
self,
) -> tuple[States, States]:
) -> Union[
Tuple[States, States, torch.Tensor, torch.Tensor],
Tuple[States, States, None, None],
]:
"""Returns all intermediate and terminating `States` from the trajectories.

This is useful for the flow matching loss, that requires its inputs to be distinguished.
@@ -373,10 +389,27 @@ def to_non_initial_intermediary_and_terminating_states(
are not s0.
"""
states = self.states

if self.conditioning is not None:
traj_len = self.states.batch_shape[0]
expand_dims = (traj_len,) + tuple(self.conditioning.shape)
intermediary_conditioning = self.conditioning.unsqueeze(0).expand(
expand_dims
)[~states.is_sink_state & ~states.is_initial_state]
conditioning = self.conditioning # n_final_states == n_trajectories.
else:
intermediary_conditioning = None
conditioning = None

intermediary_states = states[~states.is_sink_state & ~states.is_initial_state]
terminating_states = self.last_states
terminating_states.log_rewards = self.log_rewards
return intermediary_states, terminating_states
return (
intermediary_states,
terminating_states,
intermediary_conditioning,
conditioning,
)


def pad_dim0_to_target(a: torch.Tensor, target_dim0: int) -> torch.Tensor:
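The expand-then-mask pattern used in to_transitions above is easiest to see on concrete tensors. A minimal sketch with hypothetical shapes (three trajectories of length five, one two-dimensional conditioning vector per trajectory); the random mask stands in for actions.is_dummy:

```python
import torch

traj_len, n_trajectories, cond_dim = 5, 3, 2
conditioning = torch.randn(n_trajectories, cond_dim)   # one vector per trajectory
is_dummy = torch.rand(traj_len, n_trajectories) > 0.6  # stand-in for actions.is_dummy

# Repeat the per-trajectory conditioning along the time axis (expand creates a view,
# no copy), then apply the same boolean mask used to keep valid transitions, so every
# surviving transition carries the conditioning of the trajectory it came from.
expand_dims = (traj_len,) + tuple(conditioning.shape)     # (5, 3, 2)
per_step = conditioning.unsqueeze(0).expand(expand_dims)  # (5, 3, 2)
transition_conditioning = per_step[~is_dummy]             # (n_valid_transitions, 2)

assert transition_conditioning.shape == (int((~is_dummy).sum()), cond_dim)
```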
2 changes: 2 additions & 0 deletions src/gfn/containers/transitions.py
@@ -34,6 +34,7 @@ def __init__(
self,
env: Env,
states: States | None = None,
conditioning: torch.Tensor | None = None,
actions: Actions | None = None,
is_done: TT["n_transitions", torch.bool] | None = None,
next_states: States | None = None,
@@ -65,6 +66,7 @@
`batch_shapes`.
"""
self.env = env
self.conditioning = conditioning
self.is_backward = is_backward
self.states = (
states
66 changes: 50 additions & 16 deletions src/gfn/gflownet/base.py
@@ -1,10 +1,9 @@
import math
from abc import ABC, abstractmethod
from typing import Generic, Tuple, TypeVar, Union
from typing import Generic, Tuple, TypeVar, Union, Any

import torch
import torch.nn as nn
from torch import Tensor
from torchtyping import TensorType as TT

from gfn.containers import Trajectories
@@ -14,6 +13,10 @@
from gfn.samplers import Sampler
from gfn.states import States
from gfn.utils.common import has_log_probs
from gfn.utils.handlers import (
has_conditioning_exception_handler,
no_conditioning_exception_handler,
)

TrainingSampleType = TypeVar(
"TrainingSampleType", bound=Union[Container, tuple[States, ...]]
@@ -32,48 +35,48 @@ class GFlowNet(ABC, nn.Module, Generic[TrainingSampleType]):
def sample_trajectories(
self,
env: Env,
n_samples: int,
Review thread on this line:

Collaborator: It looks like you're handling the conditioning input to this function as a kwarg, whereas the sampler's sample_trajectories has an explicit conditioning input. I'm wondering if you have a particular reason for this choice.

josephdviviano (author), Oct 8, 2024: I think maybe all functions should use an explicit conditioning kwarg, what do you think? I can make those changes.

Collaborator: I agree that it would be cleaner.

josephdviviano (author), Oct 9, 2024: It should be done now, let me know if I missed something.

n: int,
save_logprobs: bool = True,
save_estimator_outputs: bool = False,
) -> Trajectories:
"""Sample a specific number of complete trajectories.

Args:
env: the environment to sample trajectories from.
n_samples: number of trajectories to be sampled.
n: number of trajectories to be sampled.
save_logprobs: whether to save the logprobs of the actions - useful for on-policy learning.
save_estimator_outputs: whether to save the estimator outputs - useful for off-policy learning
with tempered policy
Returns:
Trajectories: sampled trajectories object.
"""

def sample_terminating_states(self, env: Env, n_samples: int) -> States:
def sample_terminating_states(self, env: Env, n: int) -> States:
"""Rolls out the parametrization's policy and returns the terminating states.

Args:
env: the environment to sample terminating states from.
n_samples: number of terminating states to be sampled.
n: number of terminating states to be sampled.
Returns:
States: sampled terminating states object.
"""
trajectories = self.sample_trajectories(
env, n_samples, save_estimator_outputs=False, save_logprobs=False
env, n, save_estimator_outputs=False, save_logprobs=False
)
return trajectories.last_states

def logz_named_parameters(self):
return {"logZ": dict(self.named_parameters())["logZ"]}
return {k: v for k, v in dict(self.named_parameters()).items() if "logZ" in k}

def logz_parameters(self):
return [dict(self.named_parameters())["logZ"]]
return [v for k, v in dict(self.named_parameters()).items() if "logZ" in k]

@abstractmethod
def to_training_samples(self, trajectories: Trajectories) -> TrainingSampleType:
"""Converts trajectories to training samples. The type depends on the GFlowNet."""

@abstractmethod
def loss(self, env: Env, training_objects):
def loss(self, env: Env, training_objects: Any):
"""Computes the loss given the training objects."""


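The new logz_named_parameters / logz_parameters bodies filter on the substring "logZ" instead of looking up a parameter named exactly "logZ". A small sketch of why that matters (TinyGFN here is hypothetical, not part of the library): once logZ is produced by a conditioning-dependent sub-network, there is no single parameter called "logZ", only several whose names contain it.

```python
import torch.nn as nn


class TinyGFN(nn.Module):  # hypothetical, for illustration only
    def __init__(self):
        super().__init__()
        self.pf_head = nn.Linear(4, 4)
        # Conditional logZ: a small network mapping the conditioning vector to log Z.
        self.logZ = nn.Sequential(nn.Linear(2, 8), nn.Linear(8, 1))


model = TinyGFN()
# dict(model.named_parameters())["logZ"] would raise a KeyError here, because no
# parameter is named exactly "logZ"; the substring filter picks up the whole
# sub-network ("logZ.0.weight", "logZ.0.bias", "logZ.1.weight", "logZ.1.bias").
logz_params = [v for k, v in dict(model.named_parameters()).items() if "logZ" in k]
assert len(logz_params) == 4
```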
@@ -93,18 +96,20 @@ def __init__(self, pf: GFNModule, pb: GFNModule):
def sample_trajectories(
self,
env: Env,
n_samples: int,
n: int,
conditioning: torch.Tensor | None = None,
save_logprobs: bool = True,
save_estimator_outputs: bool = False,
**policy_kwargs,
**policy_kwargs: Any,
) -> Trajectories:
"""Samples trajectories, optionally with specified policy kwargs."""
sampler = Sampler(estimator=self.pf)
trajectories = sampler.sample_trajectories(
env,
n_trajectories=n_samples,
save_estimator_outputs=save_estimator_outputs,
n=n,
conditioning=conditioning,
save_logprobs=save_logprobs,
save_estimator_outputs=save_estimator_outputs,
**policy_kwargs,
)

@@ -176,7 +181,20 @@ def get_pfs_and_pbs(
~trajectories.actions.is_dummy
]
else:
estimator_outputs = self.pf(valid_states)
if trajectories.conditioning is not None:
cond_dim = (-1,) * len(trajectories.conditioning.shape)
traj_len = trajectories.states.tensor.shape[0]
masked_cond = trajectories.conditioning.unsqueeze(0).expand(
(traj_len,) + cond_dim
)[~trajectories.states.is_sink_state]

# Here, we pass all valid states, i.e., non-sink states.
with has_conditioning_exception_handler("pf", self.pf):
estimator_outputs = self.pf(valid_states, masked_cond)
else:
# Here, we pass all valid states, i.e., non-sink states.
with no_conditioning_exception_handler("pf", self.pf):
estimator_outputs = self.pf(valid_states)

# Calculates the log PF of the actions sampled off policy.
valid_log_pf_actions = self.pf.to_probability_distribution(
@@ -196,7 +214,23 @@

# Using all non-initial states, calculate the backward policy, and the logprobs
# of those actions.
estimator_outputs = self.pb(non_initial_valid_states)
if trajectories.conditioning is not None:

# We need to index the conditioning vector to broadcast over the states.
cond_dim = (-1,) * len(trajectories.conditioning.shape)
traj_len = trajectories.states.tensor.shape[0]
masked_cond = trajectories.conditioning.unsqueeze(0).expand(
(traj_len,) + cond_dim
)[~trajectories.states.is_sink_state][~valid_states.is_initial_state]

# Pass all valid states, i.e., non-sink states, except the initial state.
with has_conditioning_exception_handler("pb", self.pb):
estimator_outputs = self.pb(non_initial_valid_states, masked_cond)
else:
# Pass all valid states, i.e., non-sink states, except the initial state.
with no_conditioning_exception_handler("pb", self.pb):
estimator_outputs = self.pb(non_initial_valid_states)

valid_log_pb_actions = self.pb.to_probability_distribution(
non_initial_valid_states, estimator_outputs
).log_prob(non_exit_valid_actions.tensor)
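has_conditioning_exception_handler and no_conditioning_exception_handler are imported from gfn.utils.handlers, whose implementation is not part of this diff. Purely as an illustration of the intent, a hypothetical sketch of what such context managers could look like; the argument names match the call sites above, but the bodies and error messages here are assumptions:

```python
from contextlib import contextmanager


@contextmanager
def has_conditioning_exception_handler(name, estimator):
    """Hypothetical: re-raise estimator-call failures with a conditioning-specific hint."""
    try:
        yield
    except TypeError as e:
        raise TypeError(
            f"Conditioning was passed, but {name} ({type(estimator).__name__}) does not "
            "accept a conditioning argument -- is it a Conditional* estimator?"
        ) from e


@contextmanager
def no_conditioning_exception_handler(name, estimator):
    """Hypothetical: the mirror case, when an estimator expects conditioning but got none."""
    try:
        yield
    except TypeError as e:
        raise TypeError(
            f"{name} ({type(estimator).__name__}) appears to require conditioning, "
            "but none was attached to the trajectories."
        ) from e
```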