-
Notifications
You must be signed in to change notification settings - Fork 35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rethinking sampling #147
Rethinking sampling #147
Changes from all commits
95f4e01
7b0688b
0b812e5
6eeb0af
a391b28
d9fa884
2a16b1b
740e16f
9d1ffd0
963693c
76ab487
8999dd3
450ebf0
a8b637e
1acfcce
e052c82
f897aab
67ea36e
e88e1bb
b28dc95
a72afe9
46693af
119559d
e8ab999
5e87e3f
732fb0f
5048f3c
3b4e597
5ef0c22
d69d258
e29a278
ff45949
2ca4ced
a4c1786
f6edd53
6aab6e0
716ee7a
dfb929d
5d62bee
1ceb53d
b67a6d2
c419dd3
a0f43c6
d6ad17f
50b74d2
80b4c29
b5fdd32
ac42f22
05732a8
25235f0
0814725
72ac58f
b0432c9
b987a39
12eab45
cd35cb8
44050a9
038b67b
8388362
93f2e5f
a6601d7
71da6b5
bafa1ad
0990d51
aa3c656
e2ad9dd
be122ed
cfc560c
2bebde2
e7c7453
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,13 +7,20 @@ | |
from gfn.env import Env | ||
from gfn.states import States | ||
|
||
import numpy as np | ||
import torch | ||
from torch import Tensor | ||
from torchtyping import TensorType as TT | ||
|
||
from gfn.containers.base import Container | ||
from gfn.containers.transitions import Transitions | ||
|
||
|
||
def is_tensor(t) -> bool: | ||
"""Checks whether t is a torch.Tensor instance.""" | ||
return isinstance(t, Tensor) | ||
|
||
|
||
# TODO: remove env from this class? | ||
class Trajectories(Container): | ||
"""Container for complete trajectories (starting in $s_0$ and ending in $s_f$). | ||
|
@@ -47,16 +54,21 @@ def __init__( | |
is_backward: bool = False, | ||
log_rewards: TT["n_trajectories", torch.float] | None = None, | ||
log_probs: TT["max_length", "n_trajectories", torch.float] | None = None, | ||
estimator_outputs: torch.Tensor | None = None, | ||
) -> None: | ||
""" | ||
Args: | ||
env: The environment in which the trajectories are defined. | ||
states: The states of the trajectories. Defaults to None. | ||
actions: The actions of the trajectories. Defaults to None. | ||
when_is_done: The time step at which each trajectory ends. Defaults to None. | ||
is_backward: Whether the trajectories are backward or forward. Defaults to False. | ||
log_rewards: The log_rewards of the trajectories. Defaults to None. | ||
log_probs: The log probabilities of the trajectories' actions. Defaults to None. | ||
states: The states of the trajectories. | ||
actions: The actions of the trajectories. | ||
when_is_done: The time step at which each trajectory ends. | ||
is_backward: Whether the trajectories are backward or forward. | ||
log_rewards: The log_rewards of the trajectories. | ||
log_probs: The log probabilities of the trajectories' actions. | ||
estimator_outputs: When forward sampling off-policy for an n-step | ||
trajectory, n forward passes will be made on some function approximator, | ||
which may need to be re-used (for example, for evaluating PF). To avoid | ||
duplicated effort, the outputs of the forward passes can be stored here. | ||
|
||
If states is None, then the states are initialized to an empty States object, | ||
that can be populated on the fly. If log_rewards is None, then `env.log_reward` | ||
|
@@ -87,6 +99,7 @@ def __init__( | |
if log_probs is not None | ||
else torch.full(size=(0, 0), fill_value=0, dtype=torch.float) | ||
) | ||
self.estimator_outputs = estimator_outputs | ||
|
||
def __repr__(self) -> str: | ||
states = self.states.tensor.transpose(0, 1) | ||
|
@@ -154,6 +167,21 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories: | |
log_rewards = ( | ||
self._log_rewards[index] if self._log_rewards is not None else None | ||
) | ||
if is_tensor(self.estimator_outputs): | ||
# TODO: Is there a safer way to index self.estimator_outputs for | ||
# for n-dimensional estimator outputs? | ||
# | ||
# First we index along the first dimension of the estimator outputs. | ||
# This can be thought of as the instance dimension, and is | ||
# compatible with all supported indexing approaches (dim=1). | ||
# All dims > 1 are not explicitly indexed unless the dimensionality | ||
# of `index` matches all dimensions of `estimator_outputs` aside | ||
# from the first (trajectory) dimension. | ||
estimator_outputs = self.estimator_outputs[:, index] | ||
# Next we index along the trajectory length (dim=0) | ||
estimator_outputs = estimator_outputs[:new_max_length] | ||
else: | ||
estimator_outputs = None | ||
|
||
return Trajectories( | ||
env=self.env, | ||
|
@@ -163,6 +191,7 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories: | |
is_backward=self.is_backward, | ||
log_rewards=log_rewards, | ||
log_probs=log_probs, | ||
estimator_outputs=estimator_outputs, | ||
) | ||
|
||
@staticmethod | ||
|
@@ -198,7 +227,10 @@ def extend(self, other: Trajectories) -> None: | |
Args: | ||
other: an external set of Trajectories. | ||
""" | ||
if len(other) == 0: | ||
return | ||
|
||
# TODO: The replay buffer is storing `dones` - this wastes a lot of space. | ||
self.actions.extend(other.actions) | ||
self.states.extend(other.states) | ||
self.when_is_done = torch.cat((self.when_is_done, other.when_is_done), dim=0) | ||
|
@@ -213,11 +245,76 @@ def extend(self, other: Trajectories) -> None: | |
|
||
if self._log_rewards is not None and other._log_rewards is not None: | ||
self._log_rewards = torch.cat( | ||
(self._log_rewards, other._log_rewards), dim=0 | ||
(self._log_rewards, other._log_rewards), | ||
dim=0, | ||
) | ||
else: | ||
self._log_rewards = None | ||
|
||
# Either set, or append, estimator outputs if they exist in the submitted | ||
# trajectory. | ||
if self.estimator_outputs is None and is_tensor(other.estimator_outputs): | ||
self.estimator_outputs = other.estimator_outputs | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. but how would we match the indices of the trajectories to the indices of the estimator_outputs ? This feels dangerous. I suggest just throwing an error when one is None and the other is not (either one). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the idea is to be able to extend an empty I agree it is dangerous but I think we should support this behaviour. Admittedly it has been some time since I looked at this so I might be forgetting something. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fair enough! |
||
elif is_tensor(self.estimator_outputs) and is_tensor(other.estimator_outputs): | ||
batch_shape = self.actions.batch_shape | ||
n_bs = len(batch_shape) | ||
output_dtype = self.estimator_outputs.dtype | ||
|
||
if n_bs == 1: | ||
# Concatenate along the only batch dimension. | ||
self.estimator_outputs = torch.cat( | ||
(self.estimator_outputs, other.estimator_outputs), | ||
dim=0, | ||
) | ||
elif n_bs == 2: | ||
if self.estimator_outputs.shape[0] != other.estimator_outputs.shape[0]: | ||
# First we need to pad the first dimension on either self or other. | ||
self_shape = np.array(self.estimator_outputs.shape) | ||
other_shape = np.array(other.estimator_outputs.shape) | ||
required_first_dim = max(self_shape[0], other_shape[0]) | ||
|
||
# TODO: This should be a single reused function (#154) | ||
# The size of self needs to grow to match other along dim=0. | ||
if self_shape[0] < other_shape[0]: | ||
pad_dim = required_first_dim - self_shape[0] | ||
pad_dim_full = (pad_dim,) + tuple(self_shape[1:]) | ||
output_padding = torch.full( | ||
pad_dim_full, | ||
fill_value=-float("inf"), | ||
dtype=self.estimator_outputs.dtype, # TODO: This isn't working! Hence the cast below... | ||
device=self.estimator_outputs.device, | ||
) | ||
self.estimator_outputs = torch.cat( | ||
(self.estimator_outputs, output_padding), | ||
dim=0, | ||
) | ||
|
||
# The size of other needs to grow to match self along dim=0. | ||
if other_shape[0] < self_shape[0]: | ||
pad_dim = required_first_dim - other_shape[0] | ||
pad_dim_full = (pad_dim,) + tuple(other_shape[1:]) | ||
output_padding = torch.full( | ||
pad_dim_full, | ||
fill_value=-float("inf"), | ||
dtype=other.estimator_outputs.dtype, # TODO: This isn't working! Hence the cast below... | ||
device=other.estimator_outputs.device, | ||
) | ||
other.estimator_outputs = torch.cat( | ||
(other.estimator_outputs, output_padding), | ||
dim=0, | ||
) | ||
|
||
# Concatenate the tensors along the second dimension. | ||
self.estimator_outputs = torch.cat( | ||
(self.estimator_outputs, other.estimator_outputs), | ||
dim=1, | ||
).to( | ||
dtype=output_dtype | ||
) # Cast to prevent single precision becoming double precision... weird. | ||
|
||
# Sanity check. TODO: Remove? | ||
assert self.estimator_outputs.shape[:n_bs] == batch_shape | ||
|
||
def to_transitions(self) -> Transitions: | ||
"""Returns a `Transitions` object from the trajectories.""" | ||
states = self.states[:-1][~self.actions.is_dummy] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,4 @@ | ||
from abc import ABC, abstractmethod | ||
from copy import deepcopy | ||
from typing import Optional, Tuple, Union | ||
|
||
import torch | ||
|
@@ -8,6 +7,7 @@ | |
from gfn.actions import Actions | ||
from gfn.preprocessors import IdentityPreprocessor, Preprocessor | ||
from gfn.states import DiscreteStates, States | ||
from gfn.utils.common import set_seed | ||
|
||
# Errors | ||
NonValidActionsError = type("NonValidActionsError", (ValueError,), {}) | ||
|
@@ -23,7 +23,6 @@ def __init__( | |
sf: Optional[TT["state_shape", torch.float]] = None, | ||
device_str: Optional[str] = None, | ||
preprocessor: Optional[Preprocessor] = None, | ||
log_reward_clip: Optional[float] = -100.0, | ||
): | ||
"""Initializes an environment. | ||
|
||
|
@@ -37,7 +36,6 @@ def __init__( | |
preprocessor: a Preprocessor object that converts raw states to a tensor | ||
that can be fed into a neural network. Defaults to None, in which case | ||
the IdentityPreprocessor is used. | ||
log_reward_clip: Used to clip small rewards (in particular, log(0) rewards). | ||
""" | ||
self.device = torch.device(device_str) if device_str is not None else s0.device | ||
|
||
|
@@ -58,7 +56,6 @@ def __init__( | |
|
||
self.preprocessor = preprocessor | ||
self.is_discrete = False | ||
self.log_reward_clip = log_reward_clip | ||
|
||
@abstractmethod | ||
def make_States_class(self) -> type[States]: | ||
|
@@ -83,7 +80,7 @@ def reset( | |
assert not (random and sink) | ||
|
||
if random and seed is not None: | ||
torch.manual_seed(seed) | ||
set_seed(seed, performance_mode=True) | ||
|
||
if batch_shape is None: | ||
batch_shape = (1,) | ||
|
@@ -94,15 +91,15 @@ def reset( | |
) | ||
|
||
@abstractmethod | ||
def maskless_step( | ||
def maskless_step( # TODO: rename to step, other method becomes _step. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good idea ! |
||
self, states: States, actions: Actions | ||
) -> TT["batch_shape", "state_shape", torch.float]: | ||
"""Function that takes a batch of states and actions and returns a batch of next | ||
states. Does not need to check whether the actions are valid or the states are sink states. | ||
""" | ||
|
||
@abstractmethod | ||
def maskless_backward_step( | ||
def maskless_backward_step( # TODO: rename to backward_step, other method becomes _backward_step. | ||
self, states: States, actions: Actions | ||
) -> TT["batch_shape", "state_shape", torch.float]: | ||
"""Function that takes a batch of states and actions and returns a batch of previous | ||
|
@@ -134,7 +131,7 @@ def step( | |
) -> States: | ||
"""Function that takes a batch of states and actions and returns a batch of next | ||
states and a boolean tensor indicating sink states in the new batch.""" | ||
new_states = deepcopy(states) | ||
new_states = states.clone() # TODO: Ensure this is efficient! | ||
valid_states_idx: TT["batch_shape", torch.bool] = ~states.is_sink_state | ||
valid_actions = actions[valid_states_idx] | ||
valid_states = states[valid_states_idx] | ||
|
@@ -154,8 +151,6 @@ def step( | |
new_not_done_states_tensor = self.maskless_step( | ||
not_done_states, not_done_actions | ||
) | ||
# if isinstance(new_states, DiscreteStates): | ||
# new_not_done_states.masks = self.update_masks(not_done_states, not_done_actions) | ||
|
||
new_states.tensor[~new_sink_states_idx] = new_not_done_states_tensor | ||
|
||
|
@@ -168,7 +163,7 @@ def backward_step( | |
) -> States: | ||
"""Function that takes a batch of states and actions and returns a batch of next | ||
states and a boolean tensor indicating initial states in the new batch.""" | ||
new_states = deepcopy(states) | ||
new_states = states.clone() # TODO: Ensure this is efficient! | ||
valid_states_idx: TT["batch_shape", torch.bool] = ~new_states.is_initial_state | ||
valid_actions = actions[valid_states_idx] | ||
valid_states = states[valid_states_idx] | ||
|
@@ -197,8 +192,8 @@ def reward(self, final_states: States) -> TT["batch_shape", torch.float]: | |
raise NotImplementedError("Reward function is not implemented.") | ||
|
||
def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]: | ||
"""Calculates the log reward (clipping small rewards).""" | ||
return torch.log(self.reward(final_states)).clip(self.log_reward_clip) | ||
"""Calculates the log reward.""" | ||
return torch.log(self.reward(final_states)) | ||
|
||
@property | ||
def log_partition(self) -> float: | ||
|
@@ -224,7 +219,6 @@ def __init__( | |
sf: Optional[TT["state_shape", torch.float]] = None, | ||
device_str: Optional[str] = None, | ||
preprocessor: Optional[Preprocessor] = None, | ||
log_reward_clip: Optional[float] = -100.0, | ||
): | ||
"""Initializes a discrete environment. | ||
|
||
|
@@ -234,12 +228,10 @@ def __init__( | |
sf: The final state tensor (shared among all trajectories). | ||
device_str: String representation of a torch.device. | ||
preprocessor: An optional preprocessor for intermediate states. | ||
log_reward_clip: Used to clip small rewards (in particular, log(0) rewards). | ||
""" | ||
self.n_actions = n_actions | ||
super().__init__(s0, sf, device_str, preprocessor, log_reward_clip) | ||
super().__init__(s0, sf, device_str, preprocessor) | ||
self.is_discrete = True | ||
self.log_reward_clip = log_reward_clip | ||
|
||
def make_Actions_class(self) -> type[Actions]: | ||
env = self | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This implicitly assumes that
self.estimator_outputs
is of shapemax_length x n_trajectories
(as is the case for example forself.log_probs
). Would this always be the case?I feel like things would easily break here unless we force some structure on
estimator_outputs
. Rather thantorch.Tensor
, it has to be someTensorType
with a specific shape IMO.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you think of simply:
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that should work !