Skip to content

Commit

Permalink
log reward clipping removed
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Nov 23, 2023
1 parent 6aab6e0 commit 716ee7a
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 9 deletions.
3 changes: 1 addition & 2 deletions src/gfn/gym/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def __init__(
R2: float = 2.0,
epsilon: float = 1e-4,
device_str: Literal["cpu", "cuda"] = "cpu",
log_reward_clip: float = -100.0,
):
assert 0 < delta <= 1, "delta must be in (0, 1]"
self.delta = delta
Expand All @@ -31,7 +30,7 @@ def __init__(
self.R1 = R1
self.R2 = R2

super().__init__(s0=s0, log_reward_clip=log_reward_clip)
super().__init__(s0=s0)

def make_States_class(self) -> type[States]:
env = self
Expand Down
5 changes: 1 addition & 4 deletions src/gfn/gym/discrete_ebm.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def __init__(
alpha: float = 1.0,
device_str: Literal["cpu", "cuda"] = "cpu",
preprocessor_name: Literal["Identity", "Enum"] = "Identity",
log_reward_clip: float = -100.0,
):
"""Discrete EBM environment.
Expand All @@ -60,7 +59,6 @@ def __init__(
            device_str: "cpu" or "cuda". Defaults to "cpu".
            preprocessor_name: "Identity" or "Enum".
                Defaults to "Identity".
log_reward_clip: Minimum log reward allowable (namely, for log(0)).
"""
self.ndim = ndim

Expand Down Expand Up @@ -94,7 +92,6 @@ def __init__(
sf=sf,
device_str=device_str,
preprocessor=preprocessor,
log_reward_clip=log_reward_clip,
)

def make_States_class(self) -> type[DiscreteStates]:
Expand Down Expand Up @@ -195,7 +192,7 @@ def log_reward(self, final_states: DiscreteStates) -> TT["batch_shape"]:
canonical = 2 * raw_states - 1
log_reward = -self.alpha * self.energy(canonical)

return log_reward.clip(self.log_reward_clip)
return log_reward

def get_states_indices(self, states: DiscreteStates) -> TT["batch_shape"]:
"""The chosen encoding is the following: -1 -> 0, 0 -> 1, 1 -> 2, then we convert to base 3"""
Expand Down
3 changes: 0 additions & 3 deletions src/gfn/gym/hypergrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def __init__(
reward_cos: bool = False,
device_str: Literal["cpu", "cuda"] = "cpu",
preprocessor_name: Literal["KHot", "OneHot", "Identity", "Enum"] = "KHot",
log_reward_clip: float = -100.0,
):
"""HyperGrid environment from the GFlowNets paper.
The states are represented as 1-d tensors of length `ndim` with values in
Expand All @@ -42,7 +41,6 @@ def __init__(
reward_cos (bool, optional): Which version of the reward to use. Defaults to False.
            device_str (str, optional): "cpu" or "cuda". Defaults to "cpu".
            preprocessor_name (str, optional): "KHot", "OneHot", "Identity" or "Enum". Defaults to "KHot".
log_reward_clip: Minimum log reward allowable (namely, for log(0)).
"""
self.ndim = ndim
self.height = height
Expand Down Expand Up @@ -82,7 +80,6 @@ def __init__(
sf=sf,
device_str=device_str,
preprocessor=preprocessor,
log_reward_clip=log_reward_clip,
)

def make_States_class(self) -> type[DiscreteStates]:
Expand Down

0 comments on commit 716ee7a

Please sign in to comment.