diff --git a/tutorials/examples/train_line.py b/tutorials/examples/train_line.py
index f27af496..b68e5f4e 100644
--- a/tutorials/examples/train_line.py
+++ b/tutorials/examples/train_line.py
@@ -23,8 +23,8 @@ class Line(Env):
     def __init__(
         self,
-        mus: list = [-1, 1],
-        variances: list = [0.2, 0.2],
+        mus: list = [-2, 2],
+        variances: list = [0.5, 0.5],
         n_sd: float = 4.5,
         init_value: float = 0,
         n_steps_per_trajectory: int = 5,
@@ -43,13 +43,11 @@ def __init__(
         self.init_value = init_value  # Used in s0.
         self.lb = min(self.mus) - self.n_sd * max(self.sigmas)  # Convienience only.
         self.ub = max(self.mus) + self.n_sd * max(self.sigmas)  # Convienience only.
-        assert self.lb < self.init_value < self.ub
 
         # The state is [x_value, count]. x_value is initalized close to the lower bound.
         s0 = torch.tensor([self.init_value, 0.0], device=torch.device(device_str))
-        sf = torch.FloatTensor([float("inf"), float("inf")], ).to(s0.device)
-        super().__init__(s0=s0, sf=sf)  # Overwriting the default sf of -inf.
+        super().__init__(s0=s0)  # sf is -inf.
 
     def make_States_class(self) -> type[States]:
         env = self
@@ -57,15 +55,7 @@ def make_States_class(self) -> type[States]:
         class LineStates(States):
             state_shape: ClassVar[Tuple[int, ...]] = (2,)
             s0 = env.s0  # should be [init value, 0].
-            sf = env.sf  # should be [+inf, +inf].
-
-            @classmethod
-            def make_random_states_tensor(cls, batch_shape: Tuple[int, ...]) -> TT["batch_shape", 2, torch.float]:
-                # Scale [0, 1] values between lower & upper bound.
-                scaling = (self.ub - self.lb) + self.lb
-                x_val = torch.rand(batch_shape + (1,)) * scaling
-                steps = torch.full(batch_shape + (1,), self.n_steps_per_trajectory)
-                return torch.cat((x_val, steps), dim=-1, device=env.device)
+            sf = env.sf  # should be [-inf, -inf].
 
         return LineStates
 
@@ -74,8 +64,8 @@ def make_Actions_class(self) -> type[Actions]:
         class LineActions(Actions):
             action_shape: ClassVar[Tuple[int, ...]] = (1,)  # Does not include counter!
-            dummy_action: ClassVar[TT[2]] = torch.tensor([-float("inf")], device=env.device)
-            exit_action: ClassVar[TT[2]] = torch.tensor([float("inf")], device=env.device)
+            dummy_action: ClassVar[TT[2]] = torch.tensor([float("inf")], device=env.device)
+            exit_action: ClassVar[TT[2]] = torch.tensor([-float("inf")], device=env.device)
 
         return LineActions
 
@@ -90,40 +80,26 @@ def maskless_backward_step(self, states: States, actions: Actions) -> TT["batch_
         return states.tensor
 
     def is_action_valid(self, states: States, actions: Actions, backward: bool = False) -> bool:
-        """We are only going to prevent taking actions leftward beyond `S_0`."""
-        non_exit_actions = actions[~actions.is_exit]
-        non_terminal_states = states[~actions.is_exit]
-        s0_states_idx = non_terminal_states.is_initial_state
-        # Can't take a backward step at the beginning of a trajectory.
-        if torch.any(s0_states_idx) and backward:
+        non_terminal_s0_states = states[~actions.is_exit].is_initial_state
+        if torch.any(non_terminal_s0_states) and backward:
             return False
 
-        non_s0_states = non_terminal_states[~s0_states_idx].tensor
-        non_s0_actions = non_exit_actions[~s0_states_idx].tensor
-
         return True
 
     def reward(self, final_states: States) -> TT["batch_shape", torch.float]:
-        """Sum of the exponential of each log probability in the mixture."""
-        r = torch.zeros(final_states.batch_shape)
-        for m in self.mixture:
-            r = r + torch.exp(m.log_prob(final_states.tensor[..., 0]))  # x position.
-
-        return r
+        return torch.exp(self.log_reward(final_states))
 
     def log_reward(self, final_states: States) -> TT["batch_shape", torch.float]:
-        # TODO: Implement in base class.
-        log_rewards = []
-        for m in self.mixture:
-            log_rewards.append(m.log_prob(final_states.tensor[..., 0]))
-        log_rewards = torch.stack(log_rewards, 1)
-
-        if log_rewards.nelement() > 0:
-            return torch.logsumexp(log_rewards, -1)
-        else:
+        s = final_states.tensor[..., 0]
+        if s.nelement() == 0:
             return torch.zeros(final_states.batch_shape)
+        log_rewards = torch.empty((len(self.mixture),) + final_states.batch_shape)
+        for i, m in enumerate(self.mixture):
+            log_rewards[i] = m.log_prob(s)
+
+        return torch.logsumexp(log_rewards, 0)
 
     @property
     def log_partition(self) -> float:
@@ -196,7 +172,7 @@ def __init__(
         self.states = states
         self.n_steps = n_steps
         self.dist = Normal(mus, scales)
-        self.exit_action = torch.FloatTensor([float("inf")]).to(states.device)
+        self.exit_action = torch.FloatTensor([-float("inf")]).to(states.device)
         self.backward = backward
 
     def sample(self, sample_shape=()):
@@ -204,30 +180,31 @@ def sample(self, sample_shape=()):
 
         # For any state which is at the terminal step, assign the exit action.
         if not self.backward:
-            exit_mask = torch.where(
-                self.states[..., 1].tensor >= self.n_steps,  # This is the step counter.
-                torch.ones(sample_shape + (1,)),
-                torch.zeros(sample_shape + (1,)),
-            ).bool()
+            idx_at_final_step = self.states[..., 1].tensor == self.n_steps
+            exit_mask = torch.where(idx_at_final_step, 1, 0).bool()
             actions[exit_mask] = self.exit_action
 
         return actions
 
     def log_prob(self, sampled_actions):
        """TODO"""
-        # These are the exited states.
-        logprobs = torch.full_like(sampled_actions, fill_value=-float("inf"))
+        # The default value of logprobs is 0, because these represent the p=1 event
+        # of either the terminal forward (Sn->Sf) or backward (S1->S0) transition.
+        # We do not explicitly fill these values, but rather set the appropriate
+        # logprobs using the `exit_idx` mask.
+        logprobs = torch.full_like(sampled_actions, fill_value=0.0)
+        actions_to_eval = torch.full_like(sampled_actions, 0)  # Used to remove infs.
 
         # TODO: Continous Timestamp Environmemt Subclass.
         if self.backward:
-            exit = self.states[..., 1].tensor == 1  # This is the backward exit action.
-            logprobs[~exit] = self.dist.log_prob(sampled_actions)[~exit]  # This isn't efficient.
-            logprobs[exit] = 0  # log p(exit) == 1 is 0
+            exit_idx = self.states[..., 1].tensor == 1  # This is the s1->s0 action.
+            actions_to_eval[~exit_idx] = sampled_actions[~exit_idx]
+            logprobs[~exit_idx] = self.dist.log_prob(actions_to_eval)[~exit_idx]  # TODO: inefficient!
        else:  # Forward: handle exit actions.
-            exit = torch.all(sampled_actions == torch.full_like(sampled_actions[0], float("inf")), 1)  # This is the exit action.
-            if sum(~exit) > 0:
-                logprobs[~exit] = self.dist.log_prob(sampled_actions)[~exit]  # This isn't efficient.
-            logprobs[exit] = 0  # log p(exit) == 1 is 0
+            exit_idx = torch.all(sampled_actions == torch.full_like(sampled_actions[0], -float("inf")), 1)  # This is the exit action.
+            actions_to_eval[~exit_idx] = sampled_actions[~exit_idx]
+            if sum(~exit_idx) > 0:
+                logprobs[~exit_idx] = self.dist.log_prob(actions_to_eval)[~exit_idx]  # TODO: inefficient!
 
        return logprobs.squeeze(-1)
@@ -238,24 +215,16 @@ def __init__(
         self,
         hidden_dim: int,
         n_hidden_layers: int,
-        backward: bool,
-        s0_val: float,
         policy_std_min: float = 0.1,
         policy_std_max: float = 5,
     ):
         """Instantiates the neural network for the forward policy."""
         assert policy_std_min > 0
         assert policy_std_min < policy_std_max
-
-        self.input_dim = 2  # [x_pos, counter]
-        self.output_dim = 2  # [mus, scales]
-        self.s0_val = s0_val
-        self.backward = backward
         self.policy_std_min = policy_std_min
         self.policy_std_max = policy_std_max
-
-        if backward:
-            assert not math.isinf(s0_val)
+        self.input_dim = 2  # [x_pos, counter].
+        self.output_dim = 2  # [mus, scales].
 
         super().__init__(
             input_dim=self.input_dim,
@@ -268,23 +237,8 @@ def __init__(
     def forward(self, preprocessed_states: TT["batch_shape", 2, float]) -> TT["batch_shape", "3"]:
         assert preprocessed_states.ndim == 2
         out = super().forward(preprocessed_states)  # [..., 2]: represents mean & std.
-
-        # When forward, the mean can take any value. The variance must be > 0.1
         minmax_norm = (self.policy_std_max - self.policy_std_min)
         out[..., 1] = torch.sigmoid(out[..., 1]) * minmax_norm + self.policy_std_min  # Scales / Variances.
-        # print(torch.sum(out[..., 1] < self.policy_std_min))
-        # print(torch.sum(out[..., 1] > self.policy_std_max))
-        # print('--')
-
-        # if self.backward:
-        #     distance_to_s0 = preprocessed_states[..., 0] - self.s0_val
-
-        #     # At backward_step = 1, where the next step is s0, the only valid action
-        #     # to to jump directly to s0.
-        #     idx_to_s0 = preprocessed_states[..., 1] == 1  # s_1 -> s_0.
-        #     if sum(idx_to_s0) > 0:
-        #         out[idx_to_s0, 0] = distance_to_s0[idx_to_s0]
-        #         #out[idx_to_s0, 1] = 1/(2*np.pi)**0.5 # Gaussian PDF scaling factor.
 
         return out
 
@@ -318,26 +272,36 @@ def to_probability_distribution(
        )
 
 
-def get_scheduler(optim, n_iter, n_steps_scheduler=1500, scheduler_gamma=0.5):
-    return torch.optim.lr_scheduler.MultiStepLR(
-        optim,
-        milestones=[
-            i * n_steps_scheduler
-            for i in range(1, 1 + int(n_iter / n_steps_scheduler))
-        ],
-        gamma=scheduler_gamma,
-    )
+# def get_scheduler(optim, n_iter, n_steps_scheduler=1500, scheduler_gamma=0.5):
+#     return torch.optim.lr_scheduler.MultiStepLR(
+#         optim,
+#         milestones=[
+#             i * n_steps_scheduler
+#             for i in range(1, 1 + int(n_iter / n_steps_scheduler))
+#         ],
+#         gamma=scheduler_gamma,
+#     )
 
 
-def train(seed=4444, n_trajectories=3e6, batch_size=128, lr_base=1e-3, gradient_clip_value=10, n_logz_resets=0):
-    # Reproducibility.
-    torch.manual_seed(seed)
-    random.seed(seed)
+def fix_seed(seed):
    """Reproducibility."""
     np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.backends.cudnn.deterministic = True
+    random.seed(seed)
     torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.deterministic = True
     torch.manual_seed(seed)
 
+
+def train(
+    gflownet,
+    seed=4444,
+    n_trajectories=3e6,
+    batch_size=128,
+    lr_base=1e-3,
+    gradient_clip_value=10,
+    n_logz_resets=0,
+    exploration_var_starting_val=2,
+):
+    fix_seed(seed)
     device_str = "cuda" if torch.cuda.is_available() else "cpu"
     n_iterations = int(n_trajectories // batch_size)
     logz_reset_interval = 50
@@ -353,59 +317,60 @@ def train(seed=4444, n_trajectories=3e6, batch_size=128, lr_base=1e-3, gradient_
     # pb_module = BoxPBNeuralNet(hidden_dim, n_hidden_layers, n_components)
 
     # 3. Create the optimizer and scheduler.
-    optimizer = torch.optim.Adam(pf_module.parameters(), lr=lr_base)
-    logZ = dict(gflownet.named_parameters())["logZ"]
-    optimizer.add_param_group({"params": [logZ], "lr": lr_base * 10})
+    optimizer = torch.optim.Adam(gflownet.pf_pb_parameters(), lr=lr_base)
+    lr_logZ = lr_base * 100
+    optimizer.add_param_group({"params": gflownet.logz_parameters(), "lr": lr_logZ})
 
     # TODO:
     # if not uniform_pb:
     #     optimizer.add_param_group({"params": pb_module.parameters(), "lr": lr_base})
     # optimizer.add_param_group({"params": logFmodule.parameters(), "lr": lr_logF})
 
-    scheduler = get_scheduler(
-        optimizer,
-        n_iterations,
-        n_steps_scheduler=int(n_iterations / 2),
-        scheduler_gamma=0.5,
-    )
+    # scheduler = get_scheduler(
+    #     optimizer,
+    #     n_iterations,
+    #     n_steps_scheduler=int(n_iterations / 2),
+    #     scheduler_gamma=0.5,
+    # )
 
     # TODO:
     # 4. Sample from the true reward distribution, and fit a KDE to the samples.
-    n_val_samples = 1000
+    # n_val_samples = 1000
     # samples_from_reward = sample_from_reward(env, n_samples=n_val_samples)
     # true_kde = KernelDensity(kernel="exponential", bandwidth=0.1).fit(
    #     samples_from_reward
     # )
 
     # Training loop!
-    validation_interval = 1e4
+    # validation_interval = 1e4
     states_visited = 0
     jsd = float("inf")
     tbar = trange(n_iterations, desc="Training iter")
-    scale_schedule = np.linspace(0.1, 0, n_iterations)
+    scale_schedule = np.linspace(exploration_var_starting_val, 0, n_iterations)
 
     for iteration in tbar:
-        if logz_reset_count < n_logz_resets and iteration % logz_reset_interval == 0:
-            gflownet.logZ = torch.nn.init.constant_(gflownet.logZ, 0)
-            print("resetting logz")
-            logz_reset_count += 1
+        # if logz_reset_count < n_logz_resets and iteration % logz_reset_interval == 0:
+        #     gflownet.logZ = torch.nn.init.constant_(gflownet.logZ, 0)
+        #     print("resetting logz")
+        #     logz_reset_count += 1
 
         # Off Policy Sampling.
-        trajectories = gflownet.sample_trajectories(
+        trajectories, estimator_outputs = gflownet.sample_trajectories(
             env,
             n_samples=batch_size,
+            sample_off_policy=True,
             scale_factor=scale_schedule[iteration],
         )
-        print(scale_schedule[iteration])
 
         training_samples = gflownet.to_training_samples(trajectories)
-
         optimizer.zero_grad()
-        loss = gflownet.loss(env, training_samples)
+        loss = gflownet.loss(env, training_samples, estimator_outputs=estimator_outputs)
         loss.backward()
-
-        for p in gflownet.parameters():
-            print(p.grad)
+        # print("{} / {} ".format(
+        #     dict(gflownet.named_parameters())["pf.module.last_layer.bias"],
+        #     dict(gflownet.named_parameters())["pb.module.last_layer.bias"],
+        #     )
+        # )
 
         # # LESSON: Clipping
         # for p in gflownet.parameters():
@@ -413,10 +378,9 @@ def train(seed=4444, n_trajectories=3e6, batch_size=128, lr_base=1e-3, gradient_
        #     p.grad.data.clamp_(-gradient_clip_value, gradient_clip_value).nan_to_num_(0.0)
 
         optimizer.step()
-        scheduler.step()
-
+        # scheduler.step()
         states_visited += len(trajectories)
 
-        assert logZ is not None
+        assert gflownet.logz_parameters()[0].item() is not None
 
        #to_log = {"loss": loss.item(), "states_visited": states_visited}
         # logZ_info = ""
@@ -429,36 +393,36 @@ def train(seed=4444, n_trajectories=3e6, batch_size=128, lr_base=1e-3, gradient_
                 iteration,
                 states_visited,
                 loss.item(),
-                logZ.item(),
+                gflownet.logz_parameters()[0].item(),  # Assumes only one estimate of logZ.
                 env.log_partition,
                 jsd,
                 optimizer.param_groups[0]['lr'],
             )
         )
 
-        if iteration % validation_interval == 0:
-            validation_samples = gflownet.sample_terminating_states(env, n_val_samples)
-            # kde = KernelDensity(kernel="exponential", bandwidth=0.1).fit(
-            #     validation_samples.tensor.detach().cpu().numpy()
-            # )
-            # jsd = estimate_jsd(kde, true_kde)
-            #to_log.update({"JSD": jsd})
+        # if iteration % validation_interval == 0:
+        #     validation_samples = gflownet.sample_terminating_states(env, n_val_samples)
+        #     kde = KernelDensity(kernel="exponential", bandwidth=0.1).fit(
+        #         validation_samples.tensor.detach().cpu().numpy()
+        #     )
+        #     jsd = estimate_jsd(kde, true_kde)
+        #     to_log.update({"JSD": jsd})
 
-    return jsd
+    return gflownet, jsd
 
 
 if __name__ == "__main__":
-    env = Line(mus=[-1, 1], variances=[0.2, 0.2], n_sd=4.5, init_value=0.5, n_steps_per_trajectory=5)
+    env = Line(mus=[-2, 2], variances=[0.5, 0.5], n_sd=4.5, init_value=0.5, n_steps_per_trajectory=10)
 
     # Forward and backward policy estimators. We pass the lower bound from the env here.
-    hid_dim = 32
+    hid_dim = 128
+    n_hidden_layers = 2
     policy_std_min = 0.1
     policy_std_max = 1
+    exploration_var_starting_val = 2
 
     pf_module = GaussianStepNeuralNet(
         hidden_dim=hid_dim,
-        n_hidden_layers=2,
-        backward=False,
-        s0_val=env.init_value,
+        n_hidden_layers=n_hidden_layers,
         policy_std_min=policy_std_min,
         policy_std_max=policy_std_max,
     )
@@ -470,9 +434,7 @@ def train(seed=4444, n_trajectories=3e6, batch_size=128, lr_base=1e-3, gradient_
 
     pb_module = GaussianStepNeuralNet(
         hidden_dim=hid_dim,
-        n_hidden_layers=2,
-        backward=True,
-        s0_val=env.init_value,
+        n_hidden_layers=n_hidden_layers,
         policy_std_min=policy_std_min,
         policy_std_max=policy_std_max,
     )
@@ -490,6 +452,13 @@ def train(seed=4444, n_trajectories=3e6, batch_size=128, lr_base=1e-3, gradient_
     )
 
     # Magic hyperparameters: lr_base=4e-2, n_trajectories=3e6, batch_size=2048
-    train(lr_base=1e-5, n_trajectories=1e6, batch_size=64)  # I started training this with 1e-3 and then reduced it.
+    gflownet, jsd = train(
+        gflownet,
+        lr_base=1e-4,
+        n_trajectories=3e6,
+        batch_size=1024,
+        exploration_var_starting_val=exploration_var_starting_val
+    )  # I started training this with 1e-3 and then reduced it.
+
     validation_samples = gflownet.sample_terminating_states(env, 10000)
-    render(env, validation_samples=validation_samples)
+    render(env, validation_samples=validation_samples)
\ No newline at end of file
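
Note (not part of the patch above): the refactored `Line.log_reward` replaces the old exp-then-sum reward with a logsumexp over the per-component log-densities of the Gaussian mixture. The sketch below illustrates that identity with plain torch.distributions, outside the torchgfn `States`/`Env` classes; the names `mus`, `variances`, `mixture`, and `x` are illustrative stand-ins for the script's attributes.

import torch
from torch.distributions import Normal

# Stand-ins for the env's mixture: means [-2, 2] and variances [0.5, 0.5], matching the new defaults.
mus, variances = [-2.0, 2.0], [0.5, 0.5]
mixture = [Normal(m, v**0.5) for m, v in zip(mus, variances)]

x = torch.linspace(-5, 5, 11)  # plays the role of final_states.tensor[..., 0]

# New path: stack per-component log-probs and reduce with logsumexp over the component axis.
log_reward = torch.logsumexp(torch.stack([m.log_prob(x) for m in mixture], dim=0), dim=0)

# Old path: sum of exponentiated log-probs; exp(log_reward) recovers the same value.
reward = sum(torch.exp(m.log_prob(x)) for m in mixture)
assert torch.allclose(torch.exp(log_reward), reward)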