
Commit

merge
josephdviviano committed Feb 16, 2024
2 parents c393014 + eedc7e8 commit 3cb9914
Showing 16 changed files with 362 additions and 789 deletions.
875 changes: 201 additions & 674 deletions LICENSE

Large diffs are not rendered by default.

20 changes: 16 additions & 4 deletions src/gfn/containers/trajectories.py
@@ -72,7 +72,7 @@ def __init__(
self.env = env
self.is_backward = is_backward
self.states = (
states.clone() # TODO: Do we need this clone?
states
if states is not None
else env.states_from_batch_shape((0, 0))
)
@@ -160,9 +160,18 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories:
log_rewards = (
self._log_rewards[index] if self._log_rewards is not None else None
)

if isinstance(self.estimator_outputs, Tensor):
if is_tensor(self.estimator_outputs):
# TODO: Is there a safer way to index self.estimator_outputs for
# n-dimensional estimator outputs?
#
# First we index along dim=1 of the estimator outputs. This can be
# thought of as the instance (trajectory) dimension, and it is
# compatible with all supported indexing approaches.
# Dims > 1 are not explicitly indexed unless the dimensionality of
# `index` matches all dimensions of `estimator_outputs` aside from
# the first (trajectory) dimension.
estimator_outputs = self.estimator_outputs[:, index]
# Next we index along the trajectory length (dim=0)
estimator_outputs = estimator_outputs[:new_max_length]
else:
estimator_outputs = None
@@ -211,6 +220,9 @@ def extend(self, other: Trajectories) -> None:
Args:
other: an external set of Trajectories.
"""
if len(other) == 0:
return

# TODO: The replay buffer is storing `dones` - this wastes a lot of space.
self.actions.extend(other.actions)
self.states.extend(other.states)
@@ -258,7 +270,7 @@ def extend(self, other: Trajectories) -> None:
other_shape = np.array(other.estimator_outputs.shape)
required_first_dim = max(self_shape[0], other_shape[0])

# TODO: This should be a single reused function.
# TODO: This should be a single reused function (#154)
# The size of self needs to grow to match other along dim=0.
if self_shape[0] < other_shape[0]:
pad_dim = required_first_dim - self_shape[0]
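For reference, a minimal sketch of the two-step indexing performed in `__getitem__` above, assuming `estimator_outputs` is laid out as (max_length, n_trajectories, ...); the shape, `index`, and `new_max_length` values below are made up for illustration:

```python
import torch

# Hypothetical layout: (max_length, n_trajectories, output_dim).
estimator_outputs = torch.randn(5, 8, 3)

index = [1, 4, 6]    # trajectories selected by __getitem__ (illustrative)
new_max_length = 4   # longest trajectory among the selected ones (illustrative)

# First, index the trajectory (instance) dimension, dim=1 ...
selected = estimator_outputs[:, index]
# ... then clip the trajectory-length dimension, dim=0.
selected = selected[:new_max_length]

print(selected.shape)  # torch.Size([4, 3, 3])
```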
10 changes: 5 additions & 5 deletions src/gfn/env.py
@@ -8,6 +8,7 @@
from gfn.actions import Actions
from gfn.preprocessors import IdentityPreprocessor, Preprocessor
from gfn.states import DiscreteStates, States
from gfn.utils.common import set_seed

# Errors
NonValidActionsError = type("NonValidActionsError", (ValueError,), {})
@@ -171,7 +172,7 @@ def reset(
assert not (random and sink)

if random and seed is not None:
torch.manual_seed(seed) # TODO: Improve seeding here?
set_seed(seed, performance_mode=True)

if batch_shape is None:
batch_shape = (1,)
@@ -217,10 +218,9 @@ def _step(
not_done_states = new_states[~new_sink_states_idx]
not_done_actions = actions[~new_sink_states_idx]

new_not_done_states_tensor = self.step(not_done_states, not_done_actions)
# TODO: Why is this here? Should it be removed?
# if isinstance(new_states, DiscreteStates):
# new_not_done_states.masks = self.update_masks(not_done_states, not_done_actions)
new_not_done_states_tensor = self.maskless_step(
not_done_states, not_done_actions
)

new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor

24 changes: 13 additions & 11 deletions src/gfn/gflownet/base.py
@@ -26,12 +26,15 @@ class GFlowNet(ABC, nn.Module, Generic[TrainingSampleType]):
"""

@abstractmethod
def sample_trajectories(self, env: Env, n_samples: int) -> Trajectories:
def sample_trajectories(
self, env: Env, n_samples: int, sample_off_policy: bool
) -> Trajectories:
"""Sample a specific number of complete trajectories.
Args:
env: the environment to sample trajectories from.
n_samples: number of trajectories to be sampled.
sample_off_policy: if True, sample the trajectories off policy; otherwise sample them on policy.
Returns:
Trajectories: sampled trajectories object.
"""
@@ -48,12 +51,6 @@ def sample_terminating_states(self, env: Env, n_samples: int) -> States:
trajectories = self.sample_trajectories(env, n_samples, sample_off_policy=False)
return trajectories.last_states

def pf_pb_named_parameters(self):
return {k: v for k, v in self.named_parameters() if "pb" in k or "pf" in k}

def pf_pb_parameters(self):
return [v for k, v in self.named_parameters() if "pb" in k or "pf" in k]

def logz_named_parameters(self):
return {"logZ": dict(self.named_parameters())["logZ"]}

@@ -97,6 +94,12 @@ def sample_trajectories(

return trajectories

def pf_pb_named_parameters(self):
return {k: v for k, v in self.named_parameters() if "pb" in k or "pf" in k}

def pf_pb_parameters(self):
return [v for k, v in self.named_parameters() if "pb" in k or "pf" in k]


class TrajectoryBasedGFlowNet(PFBasedGFlowNet[Trajectories]):
def get_pfs_and_pbs(
@@ -148,7 +151,7 @@ def get_pfs_and_pbs(

if self.off_policy:
# We re-use the values calculated in .sample_trajectories().
if not isinstance(trajectories.estimator_outputs, type(None)):
if trajectories.estimator_outputs is not None:
estimator_outputs = trajectories.estimator_outputs[
~trajectories.actions.is_dummy
]
@@ -211,9 +214,8 @@ def get_trajectories_scores(
total_log_pb_trajectories = log_pb_trajectories.sum(dim=0)

log_rewards = trajectories.log_rewards
if math.isfinite(self.log_reward_clip_min) and not isinstance(
log_rewards, type(None)
):
# TODO: log_reward_clip_min isn't defined in base (#155).
if math.isfinite(self.log_reward_clip_min) and log_rewards is not None:
log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)

if torch.any(torch.isinf(total_log_pf_trajectories)) or torch.any(
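As a self-contained sketch of the clipping guard used in `get_trajectories_scores` above: `math.isfinite` turns the clamp into a no-op when `log_reward_clip_min` is left at its -inf default (the values below are illustrative):

```python
import math

import torch

log_reward_clip_min = -12.0  # illustrative; -float("inf") disables clipping
log_rewards = torch.tensor([-3.2, -40.0, -1.5])

if math.isfinite(log_reward_clip_min) and log_rewards is not None:
    log_rewards = log_rewards.clamp_min(log_reward_clip_min)

print(log_rewards)  # tensor([ -3.2000, -12.0000,  -1.5000])
```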
15 changes: 9 additions & 6 deletions src/gfn/gflownet/detailed_balance.py
@@ -35,12 +35,12 @@ def __init__(
logF: ScalarEstimator,
off_policy: bool,
forward_looking: bool = False,
log_reward_clamp_min: float = -float("inf"),
log_reward_clip_min: float = -float("inf"),
):
super().__init__(pf, pb, off_policy=off_policy)
self.logF = logF
self.forward_looking = forward_looking
self.log_reward_clamp_min = log_reward_clamp_min
self.log_reward_clip_min = log_reward_clip_min

def get_scores(
self, env: Env, transitions: Transitions
@@ -68,10 +68,13 @@ def get_scores(

if states.batch_shape != tuple(actions.batch_shape):
raise ValueError("Something wrong happening with log_pf evaluations")
if self.off_policy:
if not self.off_policy:
valid_log_pf_actions = transitions.log_probs
else:
# Evaluate the log PF of the actions sampled off policy.
# Perhaps the Transitions container should also have an
# estimator_outputs attribute, to avoid the duplication here?
# See (#156).
module_output = self.pf(states) # TODO: Inefficient duplication.
valid_log_pf_actions = self.pf.to_probability_distribution(
states, module_output
@@ -82,8 +85,8 @@
valid_log_F_s = self.logF(states).squeeze(-1)
if self.forward_looking:
log_rewards = env.log_reward(states) # TODO: RM unsqueeze(-1) ?
if math.isfinite(self.log_reward_clamp_min):
log_rewards = log_rewards.clamp_min(self.log_reward_clamp_min)
if math.isfinite(self.log_reward_clip_min):
log_rewards = log_rewards.clamp_min(self.log_reward_clip_min)
valid_log_F_s = valid_log_F_s + log_rewards

preds = valid_log_pf_actions + valid_log_F_s
@@ -163,7 +166,7 @@ def get_scores(self, transitions: Transitions) -> TT["n_trajectories", torch.flo
all_log_rewards = transitions.all_log_rewards[mask]
module_output = self.pf(states)
pf_dist = self.pf.to_probability_distribution(states, module_output)
if self.off_policy:
if not self.off_policy:
valid_log_pf_actions = transitions[mask].log_probs
else:
# Evaluate the log PF of the actions sampled off policy.
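To make the flipped `off_policy` branch above concrete, here is a toy restatement of the selection logic, not the library API: on-policy transitions reuse the log-probs recorded at sampling time, while off-policy transitions must be re-scored under the current forward policy. All names below are illustrative.

```python
from typing import Callable

import torch


def select_log_pf(
    off_policy: bool,
    stored_log_probs: torch.Tensor,
    rescore_with_current_pf: Callable[[], torch.Tensor],
) -> torch.Tensor:
    # Illustrative stand-in for the branch in get_scores.
    if not off_policy:
        # On policy: the sampler's log-probs already match the current PF.
        return stored_log_probs
    # Off policy: re-evaluate the log PF of the sampled actions.
    return rescore_with_current_pf()


stored = torch.tensor([-0.3, -1.2])
print(select_log_pf(False, stored, lambda: stored))  # reuses the stored values
```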
15 changes: 9 additions & 6 deletions src/gfn/samplers.py
@@ -91,7 +91,7 @@ def sample_trajectories(
off_policy: bool,
states: Optional[States] = None,
n_trajectories: Optional[int] = None,
test_mode: bool = False,
debug_mode: bool = False,
**policy_kwargs,
) -> Trajectories:
"""Sample trajectories sequentially.
@@ -110,16 +110,16 @@
parameter, `epsilon`, and `sf_bias`. In the continuous case these
kwargs will be user defined. This can be used to, for example, sample
off-policy.
test_mode: if True, everything gets calculated.
debug_mode: if True, everything (estimator outputs and log-probs) gets calculated, regardless of the off-policy setting.
Returns: A Trajectories object representing the batch of sampled trajectories.
Raises:
AssertionError: When both states and n_trajectories are specified.
AssertionError: When states are not linear.
"""
save_estimator_outputs = off_policy or test_mode
skip_logprob_calculation = off_policy and not test_mode
save_estimator_outputs = off_policy or debug_mode
skip_logprob_calculation = off_policy and not debug_mode

if states is None:
assert (
@@ -171,7 +171,7 @@ def sample_trajectories(
calculate_logprobs=False if skip_logprob_calculation else True,
**policy_kwargs,
)
if not isinstance(estimator_outputs, type(None)):
if estimator_outputs is not None:
# Place estimator outputs into a stackable tensor. Note that this
# will be replaced with torch.nested.nested_tensor in the future.
estimator_outputs_padded = torch.full(
@@ -199,11 +199,14 @@ def sample_trajectories(
# Increment the step, determine which trajectories are finished,
# and evaluate rewards.
step += 1
# new_dones marks the trajectories that have just finished. Because
# every short trajectory is padded with the sink state, we need to
# make sure to filter out the trajectories that are already done.
new_dones = (
new_states.is_initial_state
if self.estimator.is_backward
else sink_states_mask
) & ~dones # TODO: why is ~dones used here and again later on? Is this intentional?
) & ~dones
trajectories_dones[new_dones & ~dones] = step
try:
trajectories_log_rewards[new_dones & ~dones] = env.log_reward(
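A short sketch of how the two flags above follow from the renamed `debug_mode` argument: estimator outputs are kept when sampling off policy (they are needed later to re-score the trajectories) or when debugging, and log-prob calculation is skipped off policy unless `debug_mode` forces everything to be computed.

```python
def sampler_flags(off_policy: bool, debug_mode: bool) -> tuple[bool, bool]:
    # Mirrors the two assignments at the top of sample_trajectories.
    save_estimator_outputs = off_policy or debug_mode
    skip_logprob_calculation = off_policy and not debug_mode
    return save_estimator_outputs, skip_logprob_calculation


for off_policy in (False, True):
    for debug_mode in (False, True):
        save, skip = sampler_flags(off_policy, debug_mode)
        print(
            f"off_policy={off_policy!s:5} debug_mode={debug_mode!s:5} "
            f"save_estimator_outputs={save!s:5} skip_logprob_calculation={skip}"
        )
```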
8 changes: 4 additions & 4 deletions src/gfn/states.py
@@ -1,6 +1,7 @@
from __future__ import annotations # This allows to use the class name in type hints

from abc import ABC, abstractmethod
from copy import deepcopy
from math import prod
from typing import Callable, ClassVar, Optional, Sequence, cast

@@ -129,7 +130,7 @@ def __getitem__(self, index: int | Sequence[int] | Sequence[bool]) -> States:
"""Access particular states of the batch."""
return self.__class__(
self.tensor[index]
) # TODO: Inefficient - this makes a copy of the tensor!
) # TODO: Inefficient - this might make a copy of the tensor!

def __setitem__(
self, index: int | Sequence[int] | Sequence[bool], states: States
@@ -138,9 +139,8 @@ def __setitem__(
self.tensor[index] = states.tensor

def clone(self) -> States:
"""Returns a clone of the current instance."""
# TODO: Do we need to copy _log_rewards?
return self.__class__(self.tensor.detach().clone())
"""Returns a *detached* clone of the current instance using deepcopy."""
return deepcopy(self)

def flatten(self) -> States:
"""Flatten the batch dimension of the states.
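A toy illustration of the new `clone()` behavior, with a bare wrapper standing in for `States`: `deepcopy` copies both the wrapper and the underlying tensor, so mutating the clone leaves the original untouched.

```python
from copy import deepcopy

import torch


class ToyStates:
    """Stand-in for States, holding only a tensor."""

    def __init__(self, tensor: torch.Tensor) -> None:
        self.tensor = tensor

    def clone(self) -> "ToyStates":
        return deepcopy(self)


s = ToyStates(torch.zeros(2, 3))
c = s.clone()
c.tensor += 1.0

print(s.tensor.sum().item(), c.tensor.sum().item())  # 0.0 6.0
```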
77 changes: 6 additions & 71 deletions src/gfn/utils/common.py
@@ -1,82 +1,17 @@
import random
from collections import Counter
from typing import Dict, Optional

import numpy as np
import torch
from torchtyping import TensorType as TT

from gfn.containers import Trajectories, Transitions
from gfn.env import Env
from gfn.gflownet import GFlowNet, TBGFlowNet
from gfn.states import States


def get_terminating_state_dist_pmf(env: Env, states: States) -> TT["n_states", float]:
states_indices = env.get_terminating_states_indices(states).cpu().numpy().tolist()
counter = Counter(states_indices)
counter_list = [
counter[state_idx] if state_idx in counter else 0
for state_idx in range(env.n_terminating_states)
]

return torch.tensor(counter_list, dtype=torch.float) / len(states_indices)


def validate(
env: Env,
gflownet: GFlowNet,
n_validation_samples: int = 1000,
visited_terminating_states: Optional[States] = None,
) -> Dict[str, float]:
"""Evaluates the current gflownet on the given environment.
This is for environments with known target reward. The validation is done by
computing the l1 distance between the learned empirical and the target
distributions.
Args:
env: The environment to evaluate the gflownet on.
gflownet: The gflownet to evaluate.
n_validation_samples: The number of samples to use to evaluate the pmf.
visited_terminating_states: The terminating states visited during training. If given, the pmf is obtained from
these last n_validation_samples states. Otherwise, n_validation_samples are resampled for evaluation.
Returns: A dictionary containing the l1 validation metric. If the gflownet
is a TBGFlowNet, i.e. contains LogZ, then the (absolute) difference
between the learned and the target LogZ is also returned in the dictionary.
"""

true_logZ = env.log_partition
true_dist_pmf = env.true_dist_pmf
if isinstance(true_dist_pmf, torch.Tensor):
true_dist_pmf = true_dist_pmf.cpu()
else:
# The environment does not implement a true_dist_pmf property, nor a log_partition property
# We cannot validate the gflownet
return {}

logZ = None
if isinstance(gflownet, TBGFlowNet):
logZ = gflownet.logZ.item()
if visited_terminating_states is None:
terminating_states = gflownet.sample_terminating_states(n_validation_samples)
else:
terminating_states = visited_terminating_states[-n_validation_samples:]

final_states_dist_pmf = get_terminating_state_dist_pmf(env, terminating_states)
l1_dist = (final_states_dist_pmf - true_dist_pmf).abs().mean().item()
validation_info = {"l1_dist": l1_dist}
if logZ is not None:
validation_info["logZ_diff"] = abs(logZ - true_logZ)
return validation_info


def set_seed(seed: int) -> None:
def set_seed(seed: int, performance_mode: bool = False) -> None:
"""Used to control randomness."""
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# These are only set when we care about reproducibility over performance.
if not performance_mode:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
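Usage sketch for the updated helper: the default keeps the cuDNN reproducibility flags, while `performance_mode=True` (as now used internally by `Env.reset`) reseeds the RNGs without forcing deterministic cuDNN behaviour.

```python
from gfn.utils.common import set_seed

set_seed(42)                         # reproducible: cuDNN deterministic, no benchmarking
set_seed(42, performance_mode=True)  # reseed RNGs only, keep cuDNN autotuning enabled
```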
2 changes: 1 addition & 1 deletion src/gfn/utils/modules.py
@@ -50,7 +50,7 @@ def __init__(
arch.append(nn.Linear(hidden_dim, hidden_dim))
arch.append(activation())
self.torso = nn.Sequential(*arch)
self.torso.hidden_dim = hidden_dim # TODO: what is this?
self.torso.hidden_dim = hidden_dim
else:
self.torso = torso
self.last_layer = nn.Linear(self.torso.hidden_dim, output_dim)