From db13637daeb1c1e53ba3ded8111c33dcfb16ba3c Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Wed, 13 Nov 2024 15:41:14 -0500 Subject: [PATCH] black --- src/gfn/gflownet/flow_matching.py | 4 +--- src/gfn/samplers.py | 6 +++++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index 38072080..4d2f2354 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -203,9 +203,7 @@ def loss( ) return fm_loss + self.alpha * rm_loss - def to_training_samples( - self, trajectories: Trajectories - ) -> Union[ + def to_training_samples(self, trajectories: Trajectories) -> Union[ Tuple[DiscreteStates, DiscreteStates, torch.Tensor, torch.Tensor], Tuple[DiscreteStates, DiscreteStates, None, None], Tuple[States, States, torch.Tensor, torch.Tensor], diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index eb224fbf..0df9449f 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -35,7 +35,11 @@ def sample_actions( save_estimator_outputs: bool = False, save_logprobs: bool = True, **policy_kwargs: Any, - ) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None,]: + ) -> Tuple[ + Actions, + torch.Tensor | None, + torch.Tensor | None, + ]: """Samples actions from the given states. Args: