diff --git a/tutorials/examples/train_line.py b/tutorials/examples/train_line.py
index f04b9408..315c6947 100644
--- a/tutorials/examples/train_line.py
+++ b/tutorials/examples/train_line.py
@@ -208,13 +208,12 @@ def log_prob(self, sampled_actions):
         # TODO: Continous Timestamp Environmemt Subclass.
         if self.backward:
             exit_idx = self.states[..., 1].tensor == 1  # This is the s1->s0 action.
-            actions_to_eval[~exit_idx] = sampled_actions[~exit_idx]
-            logprobs[~exit_idx] = self.dist.log_prob(actions_to_eval)[~exit_idx]  # TODO: inefficient!
         else:  # Forward: handle exit actions.
-            exit_idx = torch.all(sampled_actions == torch.full_like(sampled_actions[0], -float("inf")), 1)  # This is the exit action.
-            actions_to_eval[~exit_idx] = sampled_actions[~exit_idx]
-            if sum(~exit_idx) > 0:
-                logprobs[~exit_idx] = self.dist.log_prob(actions_to_eval)[~exit_idx]  # TODO: inefficient!
+            exit_idx = torch.all(sampled_actions == -float("inf"), 1)  # This is the sn->sf action.
+
+        actions_to_eval[~exit_idx] = sampled_actions[~exit_idx]
+        if sum(~exit_idx) > 0:
+            logprobs[~exit_idx] = self.dist.log_prob(actions_to_eval)[~exit_idx]  # TODO: inefficient!
 
         return logprobs.squeeze(-1)
 
@@ -425,7 +424,7 @@ def train(
     env = Line(mus=[2, 5], variances=[0.2, 0.2], init_value=0, n_sd=4.5, n_steps_per_trajectory=5)
     # Forward and backward policy estimators. We pass the lower bound from the env here.
     hid_dim = 64
-    n_hidden_layers = 1
+    n_hidden_layers = 2
     policy_std_min = 0.1
     policy_std_max = 1
     exploration_var_starting_val = 2
@@ -465,8 +464,8 @@ def train(
     gflownet, jsd = train(
         gflownet,
         lr_base=1e-3,
-        n_trajectories=1e6,
-        batch_size=256,
+        n_trajectories=1.28e6,
+        batch_size=1,  # 256
        exploration_var_starting_val=exploration_var_starting_val
     )  # I started training this with 1e-3 and then reduced it.
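
Note: the first hunk consolidates the exit-action masking so it runs once, after the forward/backward branch, instead of being duplicated in each branch. A minimal standalone sketch of that pattern follows. It is not code from this patch: `masked_log_prob`, the toy `Normal` distribution, and the -inf initialization of `logprobs` (which the hunk does not show) are illustrative assumptions standing in for the estimator internals in train_line.py.

import torch

def masked_log_prob(dist, sampled_actions, exit_idx):
    # Exit actions are encoded as -inf, which the distribution cannot evaluate.
    # Replace them with a dummy value, evaluate log_prob on the full batch, and
    # copy the result back only for the non-exit rows; exit rows keep the
    # placeholder -inf used here for illustration.
    logprobs = torch.full_like(sampled_actions, -float("inf"))
    actions_to_eval = torch.full_like(sampled_actions, 0)
    actions_to_eval[~exit_idx] = sampled_actions[~exit_idx]
    if (~exit_idx).sum() > 0:
        logprobs[~exit_idx] = dist.log_prob(actions_to_eval)[~exit_idx]
    return logprobs.squeeze(-1)

# Toy usage: a batch of 4 one-dimensional actions, the last one being the exit action.
dist = torch.distributions.Normal(torch.zeros(4, 1), torch.ones(4, 1))
actions = torch.tensor([[0.3], [-0.1], [0.7], [-float("inf")]])
exit_idx = torch.all(actions == -float("inf"), 1)
print(masked_log_prob(dist, actions, exit_idx))  # last entry stays -inf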