Skip to content

Commit

Permalink
simplified logprobs calc
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Nov 21, 2023
1 parent f897aab commit 67ea36e
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions tutorials/examples/train_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,12 @@ def log_prob(self, sampled_actions):
# TODO: Continous Timestamp Environmemt Subclass.
if self.backward:
exit_idx = self.states[..., 1].tensor == 1 # This is the s1->s0 action.
actions_to_eval[~exit_idx] = sampled_actions[~exit_idx]
logprobs[~exit_idx] = self.dist.log_prob(actions_to_eval)[~exit_idx] # TODO: inefficient!
else: # Forward: handle exit actions.
exit_idx = torch.all(sampled_actions == torch.full_like(sampled_actions[0], -float("inf")), 1) # This is the exit action.
actions_to_eval[~exit_idx] = sampled_actions[~exit_idx]
if sum(~exit_idx) > 0:
logprobs[~exit_idx] = self.dist.log_prob(actions_to_eval)[~exit_idx] # TODO: inefficient!
exit_idx = torch.all(sampled_actions == -float("inf"), 1) # This is the sn->sf action.

actions_to_eval[~exit_idx] = sampled_actions[~exit_idx]
if sum(~exit_idx) > 0:
logprobs[~exit_idx] = self.dist.log_prob(actions_to_eval)[~exit_idx] # TODO: inefficient!

return logprobs.squeeze(-1)

Expand Down Expand Up @@ -425,7 +424,7 @@ def train(
env = Line(mus=[2, 5], variances=[0.2, 0.2], init_value=0, n_sd=4.5, n_steps_per_trajectory=5)
# Forward and backward policy estimators. We pass the lower bound from the env here.
hid_dim = 64
n_hidden_layers = 1
n_hidden_layers = 2
policy_std_min = 0.1
policy_std_max = 1
exploration_var_starting_val = 2
Expand Down Expand Up @@ -465,8 +464,8 @@ def train(
gflownet, jsd = train(
gflownet,
lr_base=1e-3,
n_trajectories=1e6,
batch_size=256,
n_trajectories=1.28e6,
batch_size=1, # 256
exploration_var_starting_val=exploration_var_starting_val
) # I started training this with 1e-3 and then reduced it.

Expand Down

0 comments on commit 67ea36e

Please sign in to comment.