Skip to content

Commit

Permalink
black
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Nov 13, 2024
1 parent 1c4ec37 commit db13637
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
4 changes: 1 addition & 3 deletions src/gfn/gflownet/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,7 @@ def loss(
)
return fm_loss + self.alpha * rm_loss

def to_training_samples(
self, trajectories: Trajectories
) -> Union[
def to_training_samples(self, trajectories: Trajectories) -> Union[
Tuple[DiscreteStates, DiscreteStates, torch.Tensor, torch.Tensor],
Tuple[DiscreteStates, DiscreteStates, None, None],
Tuple[States, States, torch.Tensor, torch.Tensor],
Expand Down
6 changes: 5 additions & 1 deletion src/gfn/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ def sample_actions(
save_estimator_outputs: bool = False,
save_logprobs: bool = True,
**policy_kwargs: Any,
) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None,]:
) -> Tuple[
Actions,
torch.Tensor | None,
torch.Tensor | None,
]:
"""Samples actions from the given states.
Args:
Expand Down

0 comments on commit db13637

Please sign in to comment.