Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix reverse_backward_trajectories and LocalSearchSampler for continuous case #233

Merged
merged 3 commits into from
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 13 additions & 21 deletions src/gfn/containers/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,17 +432,7 @@ def reverse_backward_trajectories(
trajectories: Trajectories, debug: bool = False
) -> Trajectories:
"""Reverses a backward trajectory"""
# FIXME: This method is not compatible with continuous GFN.

assert trajectories.is_backward, "Trajectories must be backward."
new_actions = torch.full(
(
trajectories.max_length + 1,
len(trajectories),
*trajectories.actions.action_shape,
),
-1,
)

# env.sf should never be None unless something went wrong during class
# instantiation.
Expand All @@ -466,8 +456,8 @@ def reverse_backward_trajectories(
) # shape (max_len + 1, n_trajectories, *state_dim)

# Initialize new actions and states
new_actions = torch.full(
(max_len + 1, len(trajectories), *trajectories.actions.action_shape), -1
new_actions = trajectories.env.dummy_action.repeat(
max_len + 1, len(trajectories), 1
).to(
actions
) # shape (max_len + 1, n_trajectories, *action_dim)
Expand Down Expand Up @@ -504,9 +494,9 @@ def reverse_backward_trajectories(

# Assign reversed actions to new_actions
new_actions[:, :-1][mask] = actions[mask][rev_idx[mask]]
new_actions[torch.arange(len(trajectories)), seq_lengths] = (
trajectories.env.n_actions - 1
) # FIXME: This can be problematic if action_dim != 1 (e.g. continuous actions)
new_actions[
torch.arange(len(trajectories)), seq_lengths
] = trajectories.env.exit_action

# Assign reversed states to new_states
assert torch.all(states[:, -1] == trajectories.env.s0), "Last state must be s0"
Expand Down Expand Up @@ -539,19 +529,21 @@ def reverse_backward_trajectories(
# If `debug` is True (expected only when testing), compare the
# vectorized approach's results (above) to the for-loop results (below).
if debug:
_new_actions = torch.full(
(max_len + 1, len(trajectories), *trajectories.actions.action_shape), -1
).to(actions)
_new_actions = trajectories.env.dummy_action.repeat(
max_len + 1, len(trajectories), 1
).to(
actions
) # shape (max_len + 1, n_trajectories, *action_dim)
_new_states = trajectories.env.sf.repeat(
max_len + 2, len(trajectories), 1
).to(
states
) # shape (max_len + 2, n_trajectories, *state_dim)

for i in range(len(trajectories)):
_new_actions[trajectories.when_is_done[i], i] = (
trajectories.env.n_actions - 1
)
_new_actions[
trajectories.when_is_done[i], i
] = trajectories.env.exit_action
_new_actions[
: trajectories.when_is_done[i], i
] = trajectories.actions.tensor[: trajectories.when_is_done[i], i].flip(
Expand Down
24 changes: 11 additions & 13 deletions src/gfn/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,8 +554,6 @@ def _combine_prev_and_recon_trajectories( # noqa: C901

bs = prev_trajectories.n_trajectories
device = prev_trajectories.states.device
state_shape = prev_trajectories.states.state_shape
action_shape = prev_trajectories.env.action_shape
env = prev_trajectories.env

# Obtain full trajectories by concatenating the backward and forward parts.
Expand Down Expand Up @@ -590,12 +588,12 @@ def _combine_prev_and_recon_trajectories( # noqa: C901

# Prepare the new states and actions
# Note that these are initialized in transposed shapes
new_trajectories_states_tsr = torch.full(
(bs, max_traj_len + 1, *state_shape), -1
).to(prev_trajectories.states.tensor)
new_trajectories_actions_tsr = torch.full(
(bs, max_traj_len, *action_shape), -1
).to(prev_trajectories.actions.tensor)
new_trajectories_states_tsr = env.sf.repeat(bs, max_traj_len + 1, 1).to(
prev_trajectories.states.tensor
)
new_trajectories_actions_tsr = env.dummy_action.repeat(bs, max_traj_len, 1).to(
prev_trajectories.actions.tensor
)

# Assign the first part (backtracked from backward policy) of the trajectory
prev_mask_truc = prev_mask[:, :max_n_prev]
Expand Down Expand Up @@ -664,11 +662,11 @@ def _combine_prev_and_recon_trajectories( # noqa: C901
# If `debug` is True (expected only when testing), compare the
# vectorized approach's results (above) to the for-loop results (below).
if debug:
_new_trajectories_states_tsr = torch.full(
(max_traj_len + 1, bs, *state_shape), -1
).to(prev_trajectories.states.tensor)
_new_trajectories_actions_tsr = torch.full(
(max_traj_len, bs, *action_shape), -1
_new_trajectories_states_tsr = env.sf.repeat(max_traj_len + 1, bs, 1).to(
prev_trajectories.states.tensor
)
_new_trajectories_actions_tsr = env.dummy_action.repeat(
max_traj_len, bs, 1
).to(prev_trajectories.actions.tensor)

if save_logprobs:
Expand Down
1 change: 1 addition & 0 deletions tutorials/examples/test_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class BoxArgs(CommonArgs):
gamma_scheduler: float = 0.5
scheduler_milestone: int = 2500
lr_F: float = 1e-2
use_local_search: bool = False


@pytest.mark.parametrize("ndim", [2, 4])
Expand Down
46 changes: 43 additions & 3 deletions tutorials/examples/train_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
BoxStateFlowModule,
)
from gfn.modules import ScalarEstimator
from gfn.samplers import LocalSearchSampler, Sampler
from gfn.utils.common import set_seed

DEFAULT_SEED = 4444
Expand Down Expand Up @@ -179,6 +180,20 @@ def main(args): # noqa: C901
)

assert gflownet is not None, f"No gflownet for loss {args.loss}"
gflownet = gflownet.to(device_str)

if not args.use_local_search:
sampler = Sampler(estimator=pf_estimator)
local_search_params = {}
else:
sampler = LocalSearchSampler(
pf_estimator=pf_estimator, pb_estimator=pb_estimator
)
local_search_params = {
"n_local_search_loops": args.n_local_search_loops,
"back_ratio": args.back_ratio,
"use_metropolis_hastings": args.use_metropolis_hastings,
}

# 3. Create the optimizer and scheduler

Expand Down Expand Up @@ -226,13 +241,13 @@ def main(args): # noqa: C901
states_visited = 0

jsd = float("inf")
for iteration in trange(n_iterations):
for iteration in trange(n_iterations, dynamic_ncols=True):
if iteration % 1000 == 0:
print(f"current optimizer LR: {optimizer.param_groups[0]['lr']}")

# Sampling on-policy, so we save logprobs for faster computation.
trajectories = gflownet.sample_trajectories(
env, save_logprobs=True, n=args.batch_size
trajectories = sampler.sample_trajectories(
env, save_logprobs=True, n=args.batch_size, **local_search_params
)

training_samples = gflownet.to_training_samples(trajectories)
Expand Down Expand Up @@ -399,6 +414,31 @@ def main(args): # noqa: C901
help="Every scheduler_milestone steps, multiply the learning rate by gamma_scheduler",
)

parser.add_argument(
"--use_local_search",
action="store_true",
help="Use local search to sample the next state",
)

# Local search parameters.
parser.add_argument(
"--n_local_search_loops",
type=int,
default=2,
help="Number of local search loops",
)
parser.add_argument(
"--back_ratio",
type=float,
default=0.5,
help="The ratio of the number of backward steps to the length of the trajectory",
)
parser.add_argument(
"--use_metropolis_hastings",
action="store_true",
help="Use Metropolis-Hastings acceptance criterion",
)

parser.add_argument(
"--n_trajectories",
type=int,
Expand Down
Loading