Merge pull request #188 from GFNOrg/conditional_gfn
conditional gfn
josephdviviano authored Oct 24, 2024
2 parents ac0832f + fd3d9dc commit d2d959e
Showing 25 changed files with 861 additions and 161 deletions.
README.md (4 changes: 2 additions & 2 deletions)
@@ -82,7 +82,7 @@ module_PF = NeuralNet(
module_PB = NeuralNet(
input_dim=env.preprocessor.output_dim,
output_dim=env.n_actions - 1,
torso=module_PF.torso # We share all the parameters of P_F and P_B, except for the last layer
trunk=module_PF.trunk # We share all the parameters of P_F and P_B, except for the last layer
)

# 3 - We define the estimators.
@@ -136,7 +136,7 @@ module_PF = NeuralNet(
module_PB = NeuralNet(
input_dim=env.preprocessor.output_dim,
output_dim=env.n_actions - 1,
torso=module_PF.torso # We share all the parameters of P_F and P_B, except for the last layer
trunk=module_PF.trunk # We share all the parameters of P_F and P_B, except for the last layer
)
module_logF = NeuralNet(
input_dim=env.preprocessor.output_dim,
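
Note on the README change: only the keyword argument is renamed, `torso` → `trunk`; P_F and P_B still share every hidden layer and differ only in their output heads. The snippet below is a plain-PyTorch stand-in that illustrates the sharing pattern; it is not the library's `NeuralNet` class, and the sizes are arbitrary.

import torch
import torch.nn as nn

# Plain-PyTorch stand-in for the README pattern: P_F and P_B share the trunk
# (all hidden layers) and differ only in their final output layer.
state_dim, n_actions, hidden = 16, 5, 64  # illustrative sizes

trunk = nn.Sequential(
    nn.Linear(state_dim, hidden), nn.ReLU(),
    nn.Linear(hidden, hidden), nn.ReLU(),
)
pf_head = nn.Linear(hidden, n_actions)      # forward policy: all actions, incl. exit
pb_head = nn.Linear(hidden, n_actions - 1)  # backward policy: no exit action

module_PF = nn.Sequential(trunk, pf_head)
module_PB = nn.Sequential(trunk, pb_head)   # same trunk object => shared parameters

x = torch.randn(4, state_dim)
assert module_PF(x).shape == (4, n_actions)
assert module_PB(x).shape == (4, n_actions - 1)
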
src/gfn/containers/trajectories.py (41 changes: 37 additions & 4 deletions)
@@ -1,11 +1,12 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Sequence, Union, Tuple


if TYPE_CHECKING:
from gfn.actions import Actions
from gfn.env import Env
from gfn.states import States
from gfn.states import States, DiscreteStates

import numpy as np
import torch
@@ -50,6 +51,7 @@ def __init__(
self,
env: Env,
states: States | None = None,
conditioning: torch.Tensor | None = None,
actions: Actions | None = None,
when_is_done: TT["n_trajectories", torch.long] | None = None,
is_backward: bool = False,
@@ -76,6 +78,7 @@ def __init__(
is used to compute the rewards, at each call of self.log_rewards
"""
self.env = env
self.conditioning = conditioning
self.is_backward = is_backward
self.states = (
states if states is not None else env.states_from_batch_shape((0, 0))
@@ -315,6 +318,15 @@ def extend(self, other: Trajectories) -> None:

def to_transitions(self) -> Transitions:
"""Returns a `Transitions` object from the trajectories."""
if self.conditioning is not None:
traj_len = self.actions.batch_shape[0]
expand_dims = (traj_len,) + tuple(self.conditioning.shape)
conditioning = self.conditioning.unsqueeze(0).expand(expand_dims)[
~self.actions.is_dummy
]
else:
conditioning = None

states = self.states[:-1][~self.actions.is_dummy]
next_states = self.states[1:][~self.actions.is_dummy]
actions = self.actions[~self.actions.is_dummy]
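
The new conditioning branch in `to_transitions()` broadcasts the per-trajectory conditioning tensor along the time axis and then keeps only the entries that correspond to real (non-dummy) actions, so the result stays aligned row-for-row with `states`, `actions`, and `next_states`. A self-contained sketch of the same expand-then-mask pattern (shapes are illustrative):

import torch

traj_len, n_trajectories, cond_dim = 5, 3, 2
conditioning = torch.randn(n_trajectories, cond_dim)
is_dummy = torch.rand(traj_len, n_trajectories) < 0.3  # stand-in for actions.is_dummy

# Prepend a time axis, expand it to traj_len, then keep only real actions.
expand_dims = (traj_len,) + tuple(conditioning.shape)
cond_per_step = conditioning.unsqueeze(0).expand(expand_dims)  # (traj_len, n_trajectories, cond_dim)
cond_for_transitions = cond_per_step[~is_dummy]                # (n_valid_transitions, cond_dim)

assert cond_for_transitions.shape == ((~is_dummy).sum().item(), cond_dim)
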
@@ -348,6 +360,7 @@ def to_transitions(self) -> Transitions:
return Transitions(
env=self.env,
states=states,
conditioning=conditioning,
actions=actions,
is_done=is_done,
next_states=next_states,
@@ -363,7 +376,10 @@ def to_states(self) -> States:

def to_non_initial_intermediary_and_terminating_states(
self,
) -> tuple[States, States]:
) -> Union[
Tuple[States, States, torch.Tensor, torch.Tensor],
Tuple[States, States, None, None],
]:
"""Returns all intermediate and terminating `States` from the trajectories.
This is useful for the flow matching loss, which requires its inputs to be distinguished.
@@ -373,10 +389,27 @@ def to_non_initial_intermediary_and_terminating_states(
are not s0.
"""
states = self.states

if self.conditioning is not None:
traj_len = self.states.batch_shape[0]
expand_dims = (traj_len,) + tuple(self.conditioning.shape)
intermediary_conditioning = self.conditioning.unsqueeze(0).expand(
expand_dims
)[~states.is_sink_state & ~states.is_initial_state]
conditioning = self.conditioning # n_final_states == n_trajectories.
else:
intermediary_conditioning = None
conditioning = None

intermediary_states = states[~states.is_sink_state & ~states.is_initial_state]
terminating_states = self.last_states
terminating_states.log_rewards = self.log_rewards
return intermediary_states, terminating_states
return (
intermediary_states,
terminating_states,
intermediary_conditioning,
conditioning,
)


def pad_dim0_to_target(a: torch.Tensor, target_dim0: int) -> torch.Tensor:
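
In `to_non_initial_intermediary_and_terminating_states`, the conditioning is now returned twice: expanded over the time axis and filtered with the same non-sink/non-initial mask as the intermediate states, and unchanged for the terminating states (one row per trajectory, since n_final_states == n_trajectories). A short self-contained check of that alignment, with synthetic masks:

import torch

traj_len, n_traj, cond_dim = 5, 3, 2
conditioning = torch.randn(n_traj, cond_dim)

# Synthetic stand-ins for states.is_initial_state / states.is_sink_state.
is_initial = torch.zeros(traj_len, n_traj, dtype=torch.bool)
is_initial[0] = True                 # s0 is the first state of every trajectory
is_sink = torch.zeros(traj_len, n_traj, dtype=torch.bool)
is_sink[3:, 0] = True                # trajectory 0 ends early, padded with s_f
is_sink[4:, 1] = True

mask = ~is_sink & ~is_initial
intermediary_conditioning = conditioning.unsqueeze(0).expand(
    (traj_len,) + tuple(conditioning.shape)
)[mask]

assert intermediary_conditioning.shape == (mask.sum().item(), cond_dim)
assert conditioning.shape[0] == n_traj   # one conditioning row per terminating state
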
src/gfn/containers/transitions.py (2 changes: 2 additions & 0 deletions)
@@ -34,6 +34,7 @@ def __init__(
self,
env: Env,
states: States | None = None,
conditioning: torch.Tensor | None = None,
actions: Actions | None = None,
is_done: TT["n_transitions", torch.bool] | None = None,
next_states: States | None = None,
@@ -65,6 +66,7 @@ def __init__(
`batch_shapes`.
"""
self.env = env
self.conditioning = conditioning
self.is_backward = is_backward
self.states = (
states
src/gfn/gflownet/base.py (66 changes: 50 additions & 16 deletions)
@@ -1,10 +1,9 @@
import math
from abc import ABC, abstractmethod
from typing import Generic, Tuple, TypeVar, Union
from typing import Generic, Tuple, TypeVar, Union, Any

import torch
import torch.nn as nn
from torch import Tensor
from torchtyping import TensorType as TT

from gfn.containers import Trajectories
@@ -14,6 +13,10 @@
from gfn.samplers import Sampler
from gfn.states import States
from gfn.utils.common import has_log_probs
from gfn.utils.handlers import (
has_conditioning_exception_handler,
no_conditioning_exception_handler,
)

TrainingSampleType = TypeVar(
"TrainingSampleType", bound=Union[Container, tuple[States, ...]]
@@ -32,48 +35,48 @@ class GFlowNet(ABC, nn.Module, Generic[TrainingSampleType]):
def sample_trajectories(
self,
env: Env,
n_samples: int,
n: int,
save_logprobs: bool = True,
save_estimator_outputs: bool = False,
) -> Trajectories:
"""Sample a specific number of complete trajectories.
Args:
env: the environment to sample trajectories from.
n_samples: number of trajectories to be sampled.
n: number of trajectories to be sampled.
save_logprobs: whether to save the logprobs of the actions - useful for on-policy learning.
save_estimator_outputs: whether to save the estimator outputs - useful for off-policy learning
with a tempered policy.
Returns:
Trajectories: sampled trajectories object.
"""

def sample_terminating_states(self, env: Env, n_samples: int) -> States:
def sample_terminating_states(self, env: Env, n: int) -> States:
"""Rolls out the parametrization's policy and returns the terminating states.
Args:
env: the environment to sample terminating states from.
n_samples: number of terminating states to be sampled.
n: number of terminating states to be sampled.
Returns:
States: sampled terminating states object.
"""
trajectories = self.sample_trajectories(
env, n_samples, save_estimator_outputs=False, save_logprobs=False
env, n, save_estimator_outputs=False, save_logprobs=False
)
return trajectories.last_states

def logz_named_parameters(self):
return {"logZ": dict(self.named_parameters())["logZ"]}
return {k: v for k, v in dict(self.named_parameters()).items() if "logZ" in k}

def logz_parameters(self):
return [dict(self.named_parameters())["logZ"]]
return [v for k, v in dict(self.named_parameters()).items() if "logZ" in k]

@abstractmethod
def to_training_samples(self, trajectories: Trajectories) -> TrainingSampleType:
"""Converts trajectories to training samples. The type depends on the GFlowNet."""

@abstractmethod
def loss(self, env: Env, training_objects):
def loss(self, env: Env, training_objects: Any):
"""Computes the loss given the training objects."""


@@ -93,18 +96,20 @@ def __init__(self, pf: GFNModule, pb: GFNModule):
def sample_trajectories(
self,
env: Env,
n_samples: int,
n: int,
conditioning: torch.Tensor | None = None,
save_logprobs: bool = True,
save_estimator_outputs: bool = False,
**policy_kwargs,
**policy_kwargs: Any,
) -> Trajectories:
"""Samples trajectories, optionally with specified policy kwargs."""
sampler = Sampler(estimator=self.pf)
trajectories = sampler.sample_trajectories(
env,
n_trajectories=n_samples,
save_estimator_outputs=save_estimator_outputs,
n=n,
conditioning=conditioning,
save_logprobs=save_logprobs,
save_estimator_outputs=save_estimator_outputs,
**policy_kwargs,
)

@@ -176,7 +181,20 @@ def get_pfs_and_pbs(
~trajectories.actions.is_dummy
]
else:
estimator_outputs = self.pf(valid_states)
if trajectories.conditioning is not None:
cond_dim = (-1,) * len(trajectories.conditioning.shape)
traj_len = trajectories.states.tensor.shape[0]
masked_cond = trajectories.conditioning.unsqueeze(0).expand(
(traj_len,) + cond_dim
)[~trajectories.states.is_sink_state]

# Here, we pass all valid states, i.e., non-sink states.
with has_conditioning_exception_handler("pf", self.pf):
estimator_outputs = self.pf(valid_states, masked_cond)
else:
# Here, we pass all valid states, i.e., non-sink states.
with no_conditioning_exception_handler("pf", self.pf):
estimator_outputs = self.pf(valid_states)

# Calculates the log PF of the actions sampled off policy.
valid_log_pf_actions = self.pf.to_probability_distribution(
@@ -196,7 +214,23 @@

# Using all non-initial states, calculate the backward policy, and the logprobs
# of those actions.
estimator_outputs = self.pb(non_initial_valid_states)
if trajectories.conditioning is not None:

# We need to index the conditioning vector to broadcast over the states.
cond_dim = (-1,) * len(trajectories.conditioning.shape)
traj_len = trajectories.states.tensor.shape[0]
masked_cond = trajectories.conditioning.unsqueeze(0).expand(
(traj_len,) + cond_dim
)[~trajectories.states.is_sink_state][~valid_states.is_initial_state]

# Pass all valid states, i.e., non-sink states, except the initial state.
with has_conditioning_exception_handler("pb", self.pb):
estimator_outputs = self.pb(non_initial_valid_states, masked_cond)
else:
# Pass all valid states, i.e., non-sink states, except the initial state.
with no_conditioning_exception_handler("pb", self.pb):
estimator_outputs = self.pb(non_initial_valid_states)

valid_log_pb_actions = self.pb.to_probability_distribution(
non_initial_valid_states, estimator_outputs
).log_prob(non_exit_valid_actions.tensor)
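
In `get_pfs_and_pbs`, the conditioning is expanded over the time axis and selected with boolean masks; for `pb` it is indexed twice, first dropping sink states and then dropping initial states among the remaining valid states, and the estimator calls are wrapped in `has_conditioning_exception_handler` / `no_conditioning_exception_handler`, presumably so that a conditioning mismatch between the GFlowNet and its estimators surfaces as a clearer error. A self-contained check that chaining the two boolean indexes is equivalent to one combined mask (masks are synthetic):

import torch

traj_len, n_traj, cond_dim = 6, 4, 3
cond = torch.randn(n_traj, cond_dim).unsqueeze(0).expand(traj_len, n_traj, cond_dim)

is_sink = torch.rand(traj_len, n_traj) < 0.3
is_initial = torch.zeros(traj_len, n_traj, dtype=torch.bool)
is_initial[0] = True          # s0 sits at the first step of each trajectory
is_initial &= ~is_sink        # an initial state is never a sink state

# As in the pb branch: filter the expanded conditioning by ~is_sink, then filter
# the survivors by "not initial". This equals a single combined mask.
not_initial_of_valid = ~is_initial[~is_sink]
chained = cond[~is_sink][not_initial_of_valid]
combined = cond[~is_sink & ~is_initial]
assert torch.equal(chained, combined)
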