diff --git a/src/gfn/gym/line.py b/src/gfn/gym/line.py index 6b246f94..a0b2534e 100644 --- a/src/gfn/gym/line.py +++ b/src/gfn/gym/line.py @@ -42,7 +42,7 @@ def __init__( action_shape=(1,), # [x_pos] dummy_action=dummy_action, exit_action=exit_action, - ) # sf is -inf by defaukt. + ) # sf is -inf by default. def step( self, states: States, actions: Actions @@ -73,11 +73,6 @@ def is_action_valid( def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]: s = final_states.tensor[..., 0] - # return torch.logsumexp(torch.stack([m.log_prob(s) for m in self.mixture], 0), 0) - - # if s.nelement() == 0: - # return torch.zeros(final_states.batch_shape) - log_rewards = torch.empty((len(self.mixture),) + final_states.batch_shape) for i, m in enumerate(self.mixture): log_rewards[i] = m.log_prob(s)