From 77abec6a03d3d9c7935f8e4d0482d8a57321f44a Mon Sep 17 00:00:00 2001 From: Salem Date: Wed, 25 Oct 2023 14:30:04 -0400 Subject: [PATCH] black and autoflake reformatting --- src/gfn/env.py | 2 +- src/gfn/gym/box.py | 2 +- src/gfn/gym/discrete_ebm.py | 4 ++-- src/gfn/gym/hypergrid.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/gfn/env.py b/src/gfn/env.py index c16d2018..90ddd240 100644 --- a/src/gfn/env.py +++ b/src/gfn/env.py @@ -23,7 +23,7 @@ def __init__( sf: Optional[TT["state_shape", torch.float]] = None, device_str: Optional[str] = None, preprocessor: Optional[Preprocessor] = None, - log_reward_clip: Optional[float] = -100., + log_reward_clip: Optional[float] = -100.0, ): """Initializes an environment. diff --git a/src/gfn/gym/box.py b/src/gfn/gym/box.py index 28eb0893..5aa272a7 100644 --- a/src/gfn/gym/box.py +++ b/src/gfn/gym/box.py @@ -20,7 +20,7 @@ def __init__( R2: float = 2.0, epsilon: float = 1e-4, device_str: Literal["cpu", "cuda"] = "cpu", - log_reward_clip: float = -100., + log_reward_clip: float = -100.0, ): assert 0 < delta <= 1, "delta must be in (0, 1]" self.delta = delta diff --git a/src/gfn/gym/discrete_ebm.py b/src/gfn/gym/discrete_ebm.py index 7839c568..ea73b336 100644 --- a/src/gfn/gym/discrete_ebm.py +++ b/src/gfn/gym/discrete_ebm.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import ClassVar, Literal, Tuple, cast +from typing import ClassVar, Literal, Tuple import torch import torch.nn as nn @@ -48,7 +48,7 @@ def __init__( alpha: float = 1.0, device_str: Literal["cpu", "cuda"] = "cpu", preprocessor_name: Literal["Identity", "Enum"] = "Identity", - log_reward_clip: float = -100., + log_reward_clip: float = -100.0, ): """Discrete EBM environment. diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py index 0ddfc2c9..d54c8a76 100644 --- a/src/gfn/gym/hypergrid.py +++ b/src/gfn/gym/hypergrid.py @@ -1,7 +1,7 @@ """ Copied and Adapted from https://github.com/Tikquuss/GflowNets_Tutorial """ -from typing import ClassVar, Literal, Tuple, cast +from typing import ClassVar, Literal, Tuple import torch from einops import rearrange @@ -25,7 +25,7 @@ def __init__( reward_cos: bool = False, device_str: Literal["cpu", "cuda"] = "cpu", preprocessor_name: Literal["KHot", "OneHot", "Identity", "Enum"] = "KHot", - log_reward_clip: float = -100., + log_reward_clip: float = -100.0, ): """HyperGrid environment from the GFlowNets paper. The states are represented as 1-d tensors of length `ndim` with values in