
Commit

black / isort
josephdviviano committed Nov 23, 2023
1 parent f6edd53 commit 6aab6e0
Showing 5 changed files with 18 additions and 13 deletions.
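
The diffs below are purely mechanical formatting changes from running isort and black. As a rough illustration of what the two tools do (a hypothetical snippet, not code from this repository, assuming black's default 88-column line length and isort's default placement of plain imports above from-imports):

# Hypothetical input, before formatting: imports out of order and a
# signature that runs past 88 columns.
#
#   from typing import Tuple
#   import math
#
#   def safe_log(value: float, minimum: float = -float("inf"), label: str = "reward") -> Tuple[str, float]:
#       return (label, max(math.log(value), minimum))

# After isort and black: plain imports sorted above from-imports, two blank
# lines before the function, and the over-long signature wrapped in its parentheses.
import math
from typing import Tuple


def safe_log(
    value: float, minimum: float = -float("inf"), label: str = "reward"
) -> Tuple[str, float]:
    return (label, max(math.log(value), minimum))

When black instead has to put each argument on its own line, it also appends a trailing comma after the last one, which is where several of the one-character changes below come from.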
src/gfn/gflownet/base.py (6 changes: 4 additions & 2 deletions)
@@ -1,6 +1,6 @@
+import math
 from abc import abstractmethod
 from typing import Tuple
-import math
 
 import torch
 import torch.nn as nn
@@ -158,7 +158,9 @@ def get_pfs_and_pbs(
             idx = torch.ones(trajectories.actions.batch_shape).bool()
             estimator_outputs = estimator_outputs[idx]
         except:
-            raise Exception("GFlowNet is off policy but no estimator_outputs found.")
+            raise Exception(
+                "GFlowNet is off policy but no estimator_outputs found."
+            )
         # else:
         #     estimator_outputs = self.pf(valid_states)
 
src/gfn/gflownet/detailed_balance.py (2 changes: 1 addition & 1 deletion)
@@ -1,5 +1,5 @@
-from typing import Tuple
 import math
+from typing import Tuple
 
 import torch
 from torchtyping import TensorType as TT
src/gfn/gflownet/sub_trajectory_balance.py (6 changes: 2 additions & 4 deletions)
@@ -1,5 +1,5 @@
-from typing import List, Literal, Tuple
 import math
+from typing import List, Literal, Tuple
 
 import torch
 from torchtyping import TensorType as TT
@@ -150,9 +150,7 @@ def get_scores(
 
             targets = torch.full_like(preds, fill_value=-float("inf"))
             assert trajectories.log_rewards is not None
-            log_rewards = trajectories.log_rewards[
-                trajectories.when_is_done >= i
-            ]
+            log_rewards = trajectories.log_rewards[trajectories.when_is_done >= i]
 
             if math.isfinite(self.log_reward_clip_min):
                 log_rewards.clamp_min(self.log_reward_clip_min)
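
Black also works in the other direction: an expression that was split across lines is joined back up once it fits within the limit, as with the log_rewards subscript above. A minimal runnable sketch with hypothetical stand-in tensors:

import torch

# Stand-ins for the trajectory fields referenced in the diff.
when_is_done = torch.tensor([1, 3, 2, 5])
log_rewards = torch.tensor([-1.0, -2.0, -0.5, -3.0])
i = 2

# Fits comfortably within 88 columns, so black keeps it on a single line;
# it would only stay split if the full line exceeded the limit.
selected = log_rewards[when_is_done >= i]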
src/gfn/gflownet/trajectory_balance.py (9 changes: 5 additions & 4 deletions)
@@ -34,11 +34,13 @@ def __init__(
         pb: GFNModule,
         on_policy: bool = False,
         init_logZ: float = 0.0,
-        log_reward_clip_min: float = -float("inf")
+        log_reward_clip_min: float = -float("inf"),
     ):
         super().__init__(pf, pb, on_policy=on_policy)
 
-        self.logZ = nn.Parameter(torch.tensor(init_logZ)) # TODO: Optionally, this should be a nn.Module to support conditional GFNs.
+        self.logZ = nn.Parameter(
+            torch.tensor(init_logZ)
+        )  # TODO: Optionally, this should be a nn.Module to support conditional GFNs.
         self.log_reward_clip_min = log_reward_clip_min
 
     def loss(
@@ -80,7 +82,7 @@ def __init__(
         pf: GFNModule,
         pb: GFNModule,
         on_policy: bool = False,
-        log_reward_clip_min: float = -float("inf")
+        log_reward_clip_min: float = -float("inf"),
     ):
         super().__init__(pf, pb, on_policy=on_policy)
         self.log_reward_clip_min = log_reward_clip_min
@@ -90,7 +92,6 @@ def loss(
         env: Env,
         trajectories: Trajectories,
         estimator_outputs: torch.Tensor = None,
-
     ) -> TT[0, float]:
         """Log Partition Variance loss.
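
The trajectory_balance.py changes show two more black behaviours: a trailing comma is appended once a parameter list spans multiple lines, and the over-long self.logZ assignment is wrapped inside its call parentheses with the TODO comment left on the closing line. A minimal sketch of a scalar log-partition parameter in that shape (a hypothetical module, not this repository's API):

import torch
import torch.nn as nn


class ScalarLogZ(nn.Module):
    """Hypothetical holder for a learnable log-partition scalar."""

    def __init__(
        self,
        init_logZ: float = 0.0,
        log_reward_clip_min: float = -float("inf"),  # trailing comma added by black
    ):
        super().__init__()
        # A 0-dim tensor wrapped in nn.Parameter gives a single learnable scalar
        # that an optimizer can update alongside the policy networks.
        self.logZ = nn.Parameter(torch.tensor(init_logZ))
        self.log_reward_clip_min = log_reward_clip_min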
src/gfn/samplers.py (8 changes: 6 additions & 2 deletions)
Expand Up @@ -131,7 +131,9 @@ def sample_trajectories(
device = states.tensor.device

dones = (
states.is_initial_state if self.estimator.is_backward else states.is_sink_state
states.is_initial_state
if self.estimator.is_backward
else states.is_sink_state
)

trajectories_states: List[TT["n_trajectories", "state_shape", torch.float]] = [
@@ -176,7 +178,9 @@ def sample_trajectories(
             trajectories_actions += [actions]
             trajectories_logprobs += [log_probs]
 
-            import IPython; IPython.embed()
+            import IPython
+
+            IPython.embed()
             if self.estimator.is_backward:
                 new_states = env.backward_step(states, actions)
             else:
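
Both samplers.py hunks are likewise formatting-only: the conditional expression that selects the done mask is split one clause per line once it no longer fits, and the stray import IPython; IPython.embed() debugging line is merely split into separate statements (with a blank line after the import) rather than removed. A small sketch of the wrapped conditional, using hypothetical stand-in objects:

import torch


class _States:
    """Hypothetical stand-in for the sampler's states container."""

    is_initial_state = torch.tensor([True, False, False])
    is_sink_state = torch.tensor([False, False, True])


class _Estimator:
    is_backward = False


states, estimator = _States(), _Estimator()

# The shape black produces when the single-line conditional would exceed
# 88 columns, as it does inside the indented method body above.
dones = (
    states.is_initial_state
    if estimator.is_backward
    else states.is_sink_state
)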
