diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index 38072080..4d2f2354 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -203,9 +203,7 @@ def loss( ) return fm_loss + self.alpha * rm_loss - def to_training_samples( - self, trajectories: Trajectories - ) -> Union[ + def to_training_samples(self, trajectories: Trajectories) -> Union[ Tuple[DiscreteStates, DiscreteStates, torch.Tensor, torch.Tensor], Tuple[DiscreteStates, DiscreteStates, None, None], Tuple[States, States, torch.Tensor, torch.Tensor], diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index eb224fbf..0df9449f 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -35,7 +35,11 @@ def sample_actions( save_estimator_outputs: bool = False, save_logprobs: bool = True, **policy_kwargs: Any, - ) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None,]: + ) -> Tuple[ + Actions, + torch.Tensor | None, + torch.Tensor | None, + ]: """Samples actions from the given states. Args: