
Commit

bugfix
josephdviviano committed Nov 21, 2023
1 parent e052c82 commit f897aab
Showing 1 changed file with 32 additions and 22 deletions.
54 changes: 32 additions & 22 deletions tutorials/examples/train_line.py
@@ -23,10 +23,10 @@ class Line(Env):
 
     def __init__(
         self,
-        mus: list = [-2, 2],
-        variances: list = [0.5, 0.5],
+        mus: list,
+        variances: list,
+        init_value: float,
         n_sd: float = 4.5,
-        init_value: float = 0,
         n_steps_per_trajectory: int = 5,
         device_str: Literal["cpu", "cuda"] = "cpu",
     ):
@@ -37,15 +37,14 @@ def __init__(
         self.n_sd = n_sd
         self.n_steps_per_trajectory = n_steps_per_trajectory
         self.mixture = [
-            Normal(torch.tensor(m), torch.tensor(s)) for m, s in zip(mus, self.sigmas)
+            Normal(m, s) for m, s in zip(self.mus, self.sigmas)
         ]
 
         self.init_value = init_value  # Used in s0.
         self.lb = min(self.mus) - self.n_sd * max(self.sigmas)  # Convenience only.
         self.ub = max(self.mus) + self.n_sd * max(self.sigmas)  # Convenience only.
         assert self.lb < self.init_value < self.ub
 
-        # The state is [x_value, count]. x_value is initialized close to the lower bound.
         s0 = torch.tensor([self.init_value, 0.0], device=torch.device(device_str))
         super().__init__(s0=s0)  # sf is -inf.
 
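For context, the reward this environment encodes is an unnormalized mixture of Gaussians over the x coordinate. A minimal sketch of that target density, assuming self.mus and self.sigmas are tensors built from mus and the square roots of variances (that setup sits outside this hunk):

import torch
from torch.distributions import Normal

mus = torch.tensor([2.0, 5.0])                 # Component means.
sigmas = torch.sqrt(torch.tensor([0.2, 0.2]))  # Std devs from the variances.
mixture = [Normal(m, s) for m, s in zip(mus, sigmas)]

x = torch.linspace(0.0, 7.0, 100)
# Unnormalized density: sum of component densities at each x.
reward = torch.stack([c.log_prob(x).exp() for c in mixture]).sum(0)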
@@ -54,7 +53,7 @@ def make_States_class(self) -> type[States]:
 
         class LineStates(States):
             state_shape: ClassVar[Tuple[int, ...]] = (2,)
-            s0 = env.s0  # should be [init value, 0].
+            s0 = env.s0  # should be [init x value, 0].
             sf = env.sf  # should be [-inf, -inf].
 
         return LineStates
@@ -81,14 +80,21 @@ def maskless_backward_step(self, states: States, actions: Actions) -> TT["batch_
 
     def is_action_valid(self, states: States, actions: Actions, backward: bool = False) -> bool:
         # Can't take a backward step at the beginning of a trajectory.
-        non_terminal_s0_states = states[~actions.is_exit].is_initial_state
-        if torch.any(non_terminal_s0_states) and backward:
+        if torch.any(states[~actions.is_exit].is_initial_state) and backward:
             return False
 
         return True
 
-    def reward(self, final_states: States) -> TT["batch_shape", torch.float]:
-        return torch.exp(self.log_reward(final_states))
+    # def reward(self, final_states: States) -> TT["batch_shape", torch.float]:
+    #     s = final_states.tensor[..., 0]
+    #     if s.nelement() == 0:
+    #         return torch.zeros(final_states.batch_shape)
+
+    #     rewards = torch.empty(final_states.batch_shape)
+    #     for i, m in enumerate(self.mixture):
+    #         rewards = rewards + torch.exp(m.log_prob(s))
+
+    #     return rewards
 
     def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]:
         s = final_states.tensor[..., 0]
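The commented-out reward() summed the component densities directly; log_reward (its body is truncated in this diff) presumably computes the same quantity in log space. A numerically stable way to write that, offered as a sketch rather than the file's actual implementation, is a logsumexp over component log-probabilities:

import torch

def mixture_log_reward(s, mixture):
    # log sum_i p_i(s), via logsumexp to avoid underflow for far-out states.
    log_probs = torch.stack([m.log_prob(s) for m in mixture], dim=0)
    return torch.logsumexp(log_probs, dim=0)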
@@ -115,9 +121,13 @@ def render(env, validation_samples=None):
         100,
     )
 
-    d = np.zeros(x.shape)
-    for mu, sigma in zip(env.mus, env.sigmas):
-        d += stats.norm.pdf(x, mu, sigma)
+    # Get the rewards from our environment.
+    r = env.States(
+        torch.tensor(
+            np.stack((x, torch.ones(len(x))), 1)  # Add dummy counter.
+        )
+    )
+    d = torch.exp(env.log_reward(r))  # Plots the reward, not the log reward.
 
     fig, ax1 = plt.subplots()
 
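This change plots exp(log_reward) from the environment itself instead of recomputing the density with scipy, so the curve always matches the reward the model is actually trained on. An equivalent pure-torch construction, assuming env.States is the class returned by make_States_class and that log_reward ignores the counter column:

import torch

x = torch.linspace(env.lb, env.ub, 100)
grid = torch.stack((x, torch.ones_like(x)), dim=1)  # [x_value, dummy counter].
d = env.log_reward(env.States(grid)).exp()          # Reward curve to plot.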
@@ -216,7 +226,7 @@ def __init__(
         hidden_dim: int,
         n_hidden_layers: int,
         policy_std_min: float = 0.1,
-        policy_std_max: float = 5,
+        policy_std_max: float = 1,
     ):
         """Instantiates the neural network for the forward policy."""
         assert policy_std_min > 0
@@ -256,7 +266,7 @@ def to_probability_distribution(
         self,
         states: States,
         module_output: TT["batch_shape", "output_dim", float],
-        scale_factor = 1,  # policy_kwarg.
+        scale_factor = 0,  # policy_kwarg.
     ) -> Distribution:
         # First, we verify that the batch shape of states is 1
         assert len(states.batch_shape) == 1
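With the default now 0, on-policy sampling adds no exploration noise. The body that consumes scale_factor is truncated here; one common pattern, shown only as a hypothetical sketch, is to widen the policy's standard deviation by the factor when sampling off-policy:

import torch
from torch.distributions import Normal

mus = torch.zeros(4)   # Stand-ins for the network's output means.
stds = torch.ones(4)   # Stand-ins for the network's output stds.
scale_factor = 0.5     # > 0 only during off-policy exploration.
dist = Normal(mus, stds + scale_factor)  # scale_factor=0 recovers the policy.
actions = dist.sample()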
@@ -360,7 +370,7 @@ def train(
             env,
             n_samples=batch_size,
             sample_off_policy=True,
-            scale_factor=scale_schedule[iteration],
+            scale_factor=scale_schedule[iteration],  # Off policy kwargs.
         )
         training_samples = gflownet.to_training_samples(trajectories)
         optimizer.zero_grad()
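scale_schedule is indexed per iteration; its definition falls outside this hunk. One plausible construction, assuming n_trajectories, batch_size, and exploration_var_starting_val as in the main block below, is a linear anneal down to zero:

import numpy as np

n_iterations = int(n_trajectories // batch_size)
scale_schedule = np.linspace(exploration_var_starting_val, 0.0, n_iterations)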
@@ -412,10 +422,10 @@ def train(
 
 if __name__ == "__main__":
 
-    env = Line(mus=[-2, 2], variances=[0.5, 0.5], n_sd=4.5, init_value=0.5, n_steps_per_trajectory=10)
+    env = Line(mus=[2, 5], variances=[0.2, 0.2], init_value=0, n_sd=4.5, n_steps_per_trajectory=5)
     # Forward and backward policy estimators. We pass the lower bound from the env here.
-    hid_dim = 128
-    n_hidden_layers = 2
+    hid_dim = 64
+    n_hidden_layers = 1
     policy_std_min = 0.1
     policy_std_max = 1
     exploration_var_starting_val = 2
@@ -454,9 +464,9 @@ def train(
     # Magic hyperparameters: lr_base=4e-2, n_trajectories=3e6, batch_size=2048
     gflownet, jsd = train(
         gflownet,
-        lr_base=1e-4,
-        n_trajectories=3e6,
-        batch_size=1024,
+        lr_base=1e-3,
+        n_trajectories=1e6,
+        batch_size=256,
         exploration_var_starting_val=exploration_var_starting_val
     )  # I started training this with 1e-3 and then reduced it.

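train() also returns jsd, a divergence diagnostic between sampled terminal states and the target mixture. Assuming it is a Jensen-Shannon divergence estimated on a histogram of terminal x values (the computation itself is outside this diff), a self-contained sketch looks like:

import torch

def jsd_estimate(samples, mixture, n_bins=100):
    # Empirical distribution p from samples; target q from the mixture density
    # evaluated on (approximate) bin centers; both normalized to sum to 1.
    lo, hi = samples.min().item(), samples.max().item()
    centers = torch.linspace(lo, hi, n_bins)
    p = torch.histc(samples, bins=n_bins, min=lo, max=hi)
    p = p / p.sum()
    q = torch.stack([m.log_prob(centers).exp() for m in mixture]).sum(0)
    q = q / q.sum()
    mid = 0.5 * (p + q)
    kl = lambda a, b: (a * ((a + 1e-12).log() - (b + 1e-12).log())).sum()
    return 0.5 * kl(p, mid) + 0.5 * kl(q, mid)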