From 716ee7a453cc0da0606a0a8046117c6f7d34eb89 Mon Sep 17 00:00:00 2001 From: Joseph Date: Thu, 23 Nov 2023 11:52:22 -0500 Subject: [PATCH] log reward clipping removed --- src/gfn/gym/box.py | 3 +-- src/gfn/gym/discrete_ebm.py | 5 +---- src/gfn/gym/hypergrid.py | 3 --- 3 files changed, 2 insertions(+), 9 deletions(-) diff --git a/src/gfn/gym/box.py b/src/gfn/gym/box.py index 5aa272a7..d5a899bd 100644 --- a/src/gfn/gym/box.py +++ b/src/gfn/gym/box.py @@ -20,7 +20,6 @@ def __init__( R2: float = 2.0, epsilon: float = 1e-4, device_str: Literal["cpu", "cuda"] = "cpu", - log_reward_clip: float = -100.0, ): assert 0 < delta <= 1, "delta must be in (0, 1]" self.delta = delta @@ -31,7 +30,7 @@ def __init__( self.R1 = R1 self.R2 = R2 - super().__init__(s0=s0, log_reward_clip=log_reward_clip) + super().__init__(s0=s0) def make_States_class(self) -> type[States]: env = self diff --git a/src/gfn/gym/discrete_ebm.py b/src/gfn/gym/discrete_ebm.py index ea73b336..9c880146 100644 --- a/src/gfn/gym/discrete_ebm.py +++ b/src/gfn/gym/discrete_ebm.py @@ -48,7 +48,6 @@ def __init__( alpha: float = 1.0, device_str: Literal["cpu", "cuda"] = "cpu", preprocessor_name: Literal["Identity", "Enum"] = "Identity", - log_reward_clip: float = -100.0, ): """Discrete EBM environment. @@ -60,7 +59,6 @@ def __init__( device_str: "cpu" or "cuda". Defaults to "cpu". preprocessor_name: "KHot" or "OneHot" or "Identity". Defaults to "KHot". - log_reward_clip: Minimum log reward allowable (namely, for log(0)). """ self.ndim = ndim @@ -94,7 +92,6 @@ def __init__( sf=sf, device_str=device_str, preprocessor=preprocessor, - log_reward_clip=log_reward_clip, ) def make_States_class(self) -> type[DiscreteStates]: @@ -195,7 +192,7 @@ def log_reward(self, final_states: DiscreteStates) -> TT["batch_shape"]: canonical = 2 * raw_states - 1 log_reward = -self.alpha * self.energy(canonical) - return log_reward.clip(self.log_reward_clip) + return log_reward def get_states_indices(self, states: DiscreteStates) -> TT["batch_shape"]: """The chosen encoding is the following: -1 -> 0, 0 -> 1, 1 -> 2, then we convert to base 3""" diff --git a/src/gfn/gym/hypergrid.py b/src/gfn/gym/hypergrid.py index d54c8a76..2c4c2859 100644 --- a/src/gfn/gym/hypergrid.py +++ b/src/gfn/gym/hypergrid.py @@ -25,7 +25,6 @@ def __init__( reward_cos: bool = False, device_str: Literal["cpu", "cuda"] = "cpu", preprocessor_name: Literal["KHot", "OneHot", "Identity", "Enum"] = "KHot", - log_reward_clip: float = -100.0, ): """HyperGrid environment from the GFlowNets paper. The states are represented as 1-d tensors of length `ndim` with values in @@ -42,7 +41,6 @@ def __init__( reward_cos (bool, optional): Which version of the reward to use. Defaults to False. device_str (str, optional): "cpu" or "cuda". Defaults to "cpu". preprocessor_name (str, optional): "KHot" or "OneHot" or "Identity". Defaults to "KHot". - log_reward_clip: Minimum log reward allowable (namely, for log(0)). """ self.ndim = ndim self.height = height @@ -82,7 +80,6 @@ def __init__( sf=sf, device_str=device_str, preprocessor=preprocessor, - log_reward_clip=log_reward_clip, ) def make_States_class(self) -> type[DiscreteStates]: