From 6aab6e039aaffe5796bfc220a5f9ee7764b95331 Mon Sep 17 00:00:00 2001
From: Joseph
Date: Thu, 23 Nov 2023 11:48:03 -0500
Subject: [PATCH] black / isort

---
 src/gfn/gflownet/base.py                   | 6 ++++--
 src/gfn/gflownet/detailed_balance.py       | 2 +-
 src/gfn/gflownet/sub_trajectory_balance.py | 6 ++----
 src/gfn/gflownet/trajectory_balance.py     | 9 +++++----
 src/gfn/samplers.py                        | 8 ++++++--
 5 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py
index cb20e59a..4cdb136e 100644
--- a/src/gfn/gflownet/base.py
+++ b/src/gfn/gflownet/base.py
@@ -1,6 +1,6 @@
+import math
 from abc import abstractmethod
 from typing import Tuple
-import math

 import torch
 import torch.nn as nn
@@ -158,7 +158,9 @@ def get_pfs_and_pbs(
                 idx = torch.ones(trajectories.actions.batch_shape).bool()
                 estimator_outputs = estimator_outputs[idx]
             except:
-                raise Exception("GFlowNet is off policy but no estimator_outputs found.")
+                raise Exception(
+                    "GFlowNet is off policy but no estimator_outputs found."
+                )

         # else:
         #     estimator_outputs = self.pf(valid_states)
diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py
index 24301b11..6e59f6c0 100644
--- a/src/gfn/gflownet/detailed_balance.py
+++ b/src/gfn/gflownet/detailed_balance.py
@@ -1,5 +1,5 @@
-from typing import Tuple
 import math
+from typing import Tuple

 import torch
 from torchtyping import TensorType as TT
diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py
index b7a3c753..05f8182c 100644
--- a/src/gfn/gflownet/sub_trajectory_balance.py
+++ b/src/gfn/gflownet/sub_trajectory_balance.py
@@ -1,5 +1,5 @@
-from typing import List, Literal, Tuple
 import math
+from typing import List, Literal, Tuple

 import torch
 from torchtyping import TensorType as TT
@@ -150,9 +150,7 @@ def get_scores(
             targets = torch.full_like(preds, fill_value=-float("inf"))

             assert trajectories.log_rewards is not None
-            log_rewards = trajectories.log_rewards[
-                trajectories.when_is_done >= i
-            ]
+            log_rewards = trajectories.log_rewards[trajectories.when_is_done >= i]

             if math.isfinite(self.log_reward_clip_min):
                 log_rewards.clamp_min(self.log_reward_clip_min)
diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py
index 91d460a7..3fe80fa2 100644
--- a/src/gfn/gflownet/trajectory_balance.py
+++ b/src/gfn/gflownet/trajectory_balance.py
@@ -34,11 +34,13 @@ def __init__(
         pb: GFNModule,
         on_policy: bool = False,
         init_logZ: float = 0.0,
-        log_reward_clip_min: float = -float("inf")
+        log_reward_clip_min: float = -float("inf"),
     ):
         super().__init__(pf, pb, on_policy=on_policy)

-        self.logZ = nn.Parameter(torch.tensor(init_logZ))  # TODO: Optionally, this should be a nn.Module to support conditional GFNs.
+        self.logZ = nn.Parameter(
+            torch.tensor(init_logZ)
+        )  # TODO: Optionally, this should be a nn.Module to support conditional GFNs.
         self.log_reward_clip_min = log_reward_clip_min

     def loss(
@@ -80,7 +82,7 @@ def __init__(
         pf: GFNModule,
         pb: GFNModule,
         on_policy: bool = False,
-        log_reward_clip_min: float = -float("inf")
+        log_reward_clip_min: float = -float("inf"),
     ):
         super().__init__(pf, pb, on_policy=on_policy)
         self.log_reward_clip_min = log_reward_clip_min
@@ -90,7 +92,6 @@ def loss(
         self,
         env: Env,
         trajectories: Trajectories,
         estimator_outputs: torch.Tensor = None,
-
     ) -> TT[0, float]:
         """Log Partition Variance loss.
diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py
index 31013c2d..fc7665da 100644
--- a/src/gfn/samplers.py
+++ b/src/gfn/samplers.py
@@ -131,7 +131,9 @@ def sample_trajectories(
         device = states.tensor.device

         dones = (
-            states.is_initial_state if self.estimator.is_backward else states.is_sink_state
+            states.is_initial_state
+            if self.estimator.is_backward
+            else states.is_sink_state
         )

         trajectories_states: List[TT["n_trajectories", "state_shape", torch.float]] = [
@@ -176,7 +178,9 @@ def sample_trajectories(
             trajectories_actions += [actions]
             trajectories_logprobs += [log_probs]

-            import IPython; IPython.embed()
+            import IPython
+
+            IPython.embed()
             if self.estimator.is_backward:
                 new_states = env.backward_step(states, actions)
             else: