From 5ce1fdc80b2d841186877203e9f244d4e3d7653d Mon Sep 17 00:00:00 2001 From: hyeok9855 Date: Fri, 6 Dec 2024 16:34:58 +0900 Subject: [PATCH] Minor refactors based on review --- src/gfn/samplers.py | 8 ++++---- tutorials/examples/train_hypergrid_simple_ls.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index ad126cc..9480f74 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -247,10 +247,10 @@ def sample_trajectories( trajectories_states = stack_states(trajectories_states) trajectories_actions = env.Actions.stack(trajectories_actions)[ - 1:, ... # Drop dummy action + 1: # Drop dummy action ] trajectories_logprobs = ( - torch.stack(trajectories_logprobs, dim=0)[1:, ...] # Drop dummy logprob + torch.stack(trajectories_logprobs, dim=0)[1:] # Drop dummy logprob if save_logprobs else None ) @@ -546,7 +546,7 @@ def sample_trajectories( n = trajectories.n_trajectories search_indices = torch.arange(n, device=trajectories.states.device) - for it in range(n_local_search_loops - 1): + for it in range(1, n_local_search_loops): # 0-th loop is the initial sampling # Search phase ls_trajectories, is_updated = self.local_search( env, @@ -562,7 +562,7 @@ def sample_trajectories( trajectories.extend(ls_trajectories) last_indices = torch.arange( - n * (it + 1), n * (it + 2), device=trajectories.states.device + n * it, n * (it + 1), device=trajectories.states.device ) search_indices[is_updated] = last_indices[is_updated] diff --git a/tutorials/examples/train_hypergrid_simple_ls.py b/tutorials/examples/train_hypergrid_simple_ls.py index 8fcc65e..382f7ef 100644 --- a/tutorials/examples/train_hypergrid_simple_ls.py +++ b/tutorials/examples/train_hypergrid_simple_ls.py @@ -38,7 +38,7 @@ def main(args): gflownet = TBGFlowNet(pf=pf_estimator, pb=pb_estimator, logZ=0.0) # Feed pf to the sampler. - sampler = LocalSearchSampler(estimator=pf_estimator, pb_estimator=pb_estimator) + sampler = LocalSearchSampler(pf_estimator=pf_estimator, pb_estimator=pb_estimator) # Move the gflownet to the GPU. gflownet = gflownet.to(device_str)