Add warm-up functionality with tensor to trajectory helper functions #224

Open · wants to merge 2 commits into base: master
2 changes: 1 addition & 1 deletion src/gfn/gflownet/base.py
@@ -71,7 +71,7 @@ def to_training_samples(self, trajectories: Trajectories) -> TrainingSampleType:
        """Converts trajectories to training samples. The type depends on the GFlowNet."""

    @abstractmethod
-    def loss(self, env: Env, training_objects: Any):
+    def loss(self, env: Env, training_objects: Any) -> torch.Tensor:
        """Computes the loss given the training objects."""
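The added return annotation makes the contract explicit for call sites: whatever the concrete GFlowNet, `loss` now promises a tensor that can be backpropagated through. A minimal sketch of what callers can rely on (the variable names are illustrative, not from this PR):

    training_samples = gfn.to_training_samples(trajectories)
    loss = gfn.loss(env, training_samples)  # now statically known to be a torch.Tensor
    loss.backward()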
75 changes: 72 additions & 3 deletions src/gfn/utils/training.py
@@ -1,11 +1,14 @@
from collections import Counter
from typing import Dict, Optional

-import torch
from tqdm import trange

-from gfn.env import Env
-from gfn.gflownet import GFlowNet
-from gfn.states import States
+from gfn.containers import ReplayBuffer
+from gfn.env import Env, DiscreteEnv
+from gfn.gflownet import GFlowNet, TBGFlowNet
+from gfn.samplers import Trajectories
+from gfn.states import States, stack_states
+import torch


def get_terminating_state_dist_pmf(env: Env, states: States) -> torch.Tensor:
@@ -74,3 +77,69 @@ def validate(
    if logZ is not None:
        validation_info["logZ_diff"] = abs(logZ - true_logZ)
    return validation_info


def states_actions_tns_to_traj(
    states_tns: torch.Tensor,
    actions_tns: torch.Tensor,
    env: DiscreteEnv,
) -> Trajectories:
    """Converts raw tensors for a single trajectory into a Trajectories object.

    TODO: the shape assumptions need to be refined with the torchgfn gang.
    - states_tns currently holds the states of a single trajectory,
      with shape [traj_len, state_ndim].
    - actions_tns holds the actions of that same trajectory,
      with shape [traj_len].
    """
    states = [env.states_from_tensor(s.unsqueeze(0)) for s in states_tns]
    actions = [
        env.actions_from_tensor(a.unsqueeze(0).unsqueeze(0)) for a in actions_tns
    ]

    # Actions.stack is a class method; actions[0] is used only to access it,
    # so the choice of instance is irrelevant.
    actions = actions[0].stack(actions)
    # states[-1] is the sink state, so states[-2] is the terminating state.
    log_rewards = env.log_reward(states[-2])
    states = stack_states(states)
    when_is_done = torch.tensor([len(states_tns) - 1])

    # WARNING: This is sketchy. Dummy log-probs avoid indexing / batch-shape errors.
    # WARNING: Assumes gfn.loss() is called with recalculate_all_logprobs=True,
    # so only TBGFlowNets are supported right now!
    # WARNING: To reviewers: can we bypass needing to define this?
    log_probs = torch.zeros((len(actions), 1), dtype=torch.float)

    # Dummy estimator outputs, for the same reason as log_probs above.
    estimator_outputs = torch.zeros((len(actions), 1, env.n_actions))
    trajectory = Trajectories(
        env,
        states,
        actions,
        log_rewards=log_rewards,
        when_is_done=when_is_done,
        log_probs=log_probs,
        estimator_outputs=estimator_outputs,
    )
    return trajectory
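
A usage sketch for reviewers. This assumes the HyperGrid environment from gfn.gym, its convention that sink states are filled with -1, and that the last action index is the exit action; note that actions_tns here has one fewer entry than states_tns, which is exactly the shape question flagged in the TODO above:

    from gfn.gym import HyperGrid

    env = HyperGrid(ndim=2, height=4)  # n_actions == 3; action 2 is exit

    # One trajectory (0,0) -> (1,0) -> (1,1), then exit. The final row is the
    # sink state, so states[-2] above is the terminating state (1,1).
    states_tns = torch.tensor([[0, 0], [1, 0], [1, 1], [-1, -1]])
    actions_tns = torch.tensor([0, 1, 2])

    traj = states_actions_tns_to_traj(states_tns, actions_tns, env)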


def warm_up(
    replay_buf: ReplayBuffer,
    optimizer: torch.optim.Optimizer,
    gfn: GFlowNet,
    env: Env,
    n_steps: int,
    batch_size: int,
    recalculate_all_logprobs: bool = True,
) -> None:
    """Trains the GFlowNet for n_steps on batches sampled from the replay buffer."""
    t = trange(n_steps, desc="Warm-up", leave=True)
    for epoch in t:
        training_trajs = replay_buf.sample(batch_size)
        optimizer.zero_grad()
        if isinstance(gfn, TBGFlowNet):
            loss = gfn.loss(
                env, training_trajs, recalculate_all_logprobs=recalculate_all_logprobs
            )
        else:
            loss = gfn.loss(env, training_trajs)

        loss.backward()
        optimizer.step()
        t.set_description(f"epoch={epoch}, loss={loss.item():.4f}")
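
Finally, a sketch of the intended end-to-end flow, combining both helpers. The ReplayBuffer constructor arguments, the optimizer settings, and the pre-built gfn are assumptions for illustration, not part of this PR:

    # Hypothetical wiring: the buffer construction args are assumed, not from this PR.
    replay_buf = ReplayBuffer(env, objects_type="trajectories", capacity=1000)
    replay_buf.add(states_actions_tns_to_traj(states_tns, actions_tns, env))

    # gfn is assumed to be an already-constructed TBGFlowNet (an nn.Module).
    optimizer = torch.optim.Adam(gfn.parameters(), lr=1e-3)
    warm_up(replay_buf, optimizer, gfn, env, n_steps=100, batch_size=16)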