Skip to content

Commit

Permalink
black and autoflake reformatting
Browse files Browse the repository at this point in the history
  • Loading branch information
saleml committed Oct 25, 2023
1 parent 382f396 commit 77abec6
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/gfn/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(
sf: Optional[TT["state_shape", torch.float]] = None,
device_str: Optional[str] = None,
preprocessor: Optional[Preprocessor] = None,
log_reward_clip: Optional[float] = -100.,
log_reward_clip: Optional[float] = -100.0,
):
"""Initializes an environment.
Expand Down
2 changes: 1 addition & 1 deletion src/gfn/gym/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(
R2: float = 2.0,
epsilon: float = 1e-4,
device_str: Literal["cpu", "cuda"] = "cpu",
log_reward_clip: float = -100.,
log_reward_clip: float = -100.0,
):
assert 0 < delta <= 1, "delta must be in (0, 1]"
self.delta = delta
Expand Down
4 changes: 2 additions & 2 deletions src/gfn/gym/discrete_ebm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import ClassVar, Literal, Tuple, cast
from typing import ClassVar, Literal, Tuple

import torch
import torch.nn as nn
Expand Down Expand Up @@ -48,7 +48,7 @@ def __init__(
alpha: float = 1.0,
device_str: Literal["cpu", "cuda"] = "cpu",
preprocessor_name: Literal["Identity", "Enum"] = "Identity",
log_reward_clip: float = -100.,
log_reward_clip: float = -100.0,
):
"""Discrete EBM environment.
Expand Down
4 changes: 2 additions & 2 deletions src/gfn/gym/hypergrid.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Copied and Adapted from https://github.com/Tikquuss/GflowNets_Tutorial
"""
from typing import ClassVar, Literal, Tuple, cast
from typing import ClassVar, Literal, Tuple

import torch
from einops import rearrange
Expand All @@ -25,7 +25,7 @@ def __init__(
reward_cos: bool = False,
device_str: Literal["cpu", "cuda"] = "cpu",
preprocessor_name: Literal["KHot", "OneHot", "Identity", "Enum"] = "KHot",
log_reward_clip: float = -100.,
log_reward_clip: float = -100.0,
):
"""HyperGrid environment from the GFlowNets paper.
The states are represented as 1-d tensors of length `ndim` with values in
Expand Down

0 comments on commit 77abec6

Please sign in to comment.