From 8b86b5fecdcc2dd52f90ab0ae18df23d7b77717f Mon Sep 17 00:00:00 2001
From: Alexander Koch
Date: Wed, 28 Oct 2020 09:17:03 +0100
Subject: [PATCH] Add V-MPO

---
 examples/mujoco_vmpo.py                |  58 ++++++
 rlpyt/agents/pg/gaussian_vmpo_agent.py |  67 +++++++
 rlpyt/algos/pg/v_mpo.py                | 245 +++++++++++++++++++++++++
 rlpyt/models/pg/mujoco_ff_model.py     |  54 ++++++
 rlpyt/models/popart_normalization.py   | 105 +++++++++++
 rlpyt/utils/tensor.py                  |  15 ++
 6 files changed, 544 insertions(+)
 create mode 100644 examples/mujoco_vmpo.py
 create mode 100644 rlpyt/agents/pg/gaussian_vmpo_agent.py
 create mode 100644 rlpyt/algos/pg/v_mpo.py
 create mode 100644 rlpyt/models/popart_normalization.py

diff --git a/examples/mujoco_vmpo.py b/examples/mujoco_vmpo.py
new file mode 100644
index 00000000..5f7ba180
--- /dev/null
+++ b/examples/mujoco_vmpo.py
@@ -0,0 +1,58 @@

"""
Runs one instance of the Mujoco environment and optimizes it using the V-MPO algorithm.
"""

from rlpyt.samplers.serial.sampler import SerialSampler
from rlpyt.runners.minibatch_rl import MinibatchRlEval
from rlpyt.utils.logging.context import logger_context
from rlpyt.algos.pg.v_mpo import VMPO
from rlpyt.agents.pg.gaussian_vmpo_agent import MujocoVmpoAgent
from rlpyt.models.pg.mujoco_ff_model import MujocoVmpoFfModel
from rlpyt.envs.gym import make as gym_make
from rlpyt.samplers.parallel.cpu.sampler import CpuSampler
from rlpyt.utils.launching.affinity import make_affinity


def build_and_train(id="Ant-v3", run_ID=0, cuda_idx=None):
    # CPU-only run: the affinity is built with n_gpu=0, so cuda_idx is accepted but unused here.
    affinity = make_affinity(n_cpu_core=24, cpu_per_run=24, n_gpu=0, set_affinity=True)
    sampler = CpuSampler(
        EnvCls=gym_make,
        env_kwargs=dict(id=id),
        eval_env_kwargs=dict(id=id),
        batch_T=40,  # 40 time-steps per sampler iteration.
        batch_B=64 * 100,
        max_decorrelation_steps=100,
        eval_n_envs=1,
        eval_max_steps=int(10e8),
        eval_max_trajectories=8,
    )
    algo = VMPO(T_target_steps=100, pop_art_reward_normalization=True, discrete_actions=False, epochs=1)
    agent = MujocoVmpoAgent(ModelCls=MujocoVmpoFfModel, model_kwargs=dict(linear_value_output=False, layer_norm=True))
    runner = MinibatchRlEval(
        algo=algo,
        agent=agent,
        sampler=sampler,
        n_steps=int(1e10),
        log_interval_steps=int(1e6),
        affinity=affinity
    )
    config = dict(id=id)
    name = "vmpo_" + id
    log_dir = "vmpo_mujoco"
    with logger_context(log_dir, run_ID, name, config, snapshot_mode="last"):
        runner.train()


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--id', help='gym env id', default='Ant-v3')
    parser.add_argument('--run_ID', help='run identifier (logging)', type=int, default=0)
    parser.add_argument('--cuda_idx', help='gpu to use', type=int, default=None)
    args = parser.parse_args()
    build_and_train(
        id=args.id,
        run_ID=args.run_ID,
        cuda_idx=args.cuda_idx,
    )

diff --git a/rlpyt/agents/pg/gaussian_vmpo_agent.py b/rlpyt/agents/pg/gaussian_vmpo_agent.py
new file mode 100644
index 00000000..d79470be
--- /dev/null
+++ b/rlpyt/agents/pg/gaussian_vmpo_agent.py
@@ -0,0 +1,67 @@
from rlpyt.distributions.gaussian import Gaussian
import torch
from rlpyt.agents.pg.base import AgentInfo, AgentInfoRnn
from rlpyt.utils.buffer import buffer_to, buffer_func, buffer_method
from rlpyt.agents.pg.mujoco import MujocoMixin, AlternatingRecurrentGaussianPgAgent
from rlpyt.agents.base import (AgentStep, BaseAgent, RecurrentAgentMixin,
                               AlternatingRecurrentAgentMixin)
from rlpyt.utils.collections import namedarraytuple
from rlpyt.distributions.gaussian import DistInfoStd

DistInfo = namedarraytuple("DistInfo", ["mean", "std"])


class VmpoAgent(RecurrentAgentMixin, BaseAgent):
    """
    Base class for Gaussian V-MPO agents. This version uses a Gaussian with a
    diagonal covariance matrix and expects a mu and std vector from the model.
    The model should already have passed the std output through a softplus to
    ensure positive values (exp on the std output seems to work slightly worse).
    Note that the std is stored in the log_std field of DistInfoStd so that the
    existing buffer definitions can be reused.
    """

    def initialize(self, env_spaces, *args, **kwargs):
        """Extends base method to build Gaussian distribution."""
        super().initialize(env_spaces, *args, **kwargs)
        self.distribution = Gaussian(
            dim=env_spaces.action.shape[0],
            # squash=env_spaces.action.high[0],
            # min_std=MIN_STD,
            # clip=env_spaces.action.high[0],  # Probably +1?
        )

    def __call__(self, observation, prev_action, prev_reward, init_rnn_state):
        """Performs forward pass on training data, for algorithm."""
        model_inputs = buffer_to((observation, prev_action, prev_reward, init_rnn_state),
            device=self.device)
        mu, std, value, rnn_state = self.model(*model_inputs)
        dist_info, value = buffer_to((DistInfoStd(mean=mu, log_std=std), value), device="cpu")
        return dist_info, value, rnn_state

    @torch.no_grad()
    def step(self, observation, prev_action, prev_reward):
        agent_inputs = buffer_to((observation, prev_action, prev_reward), device=self.device)
        mu, std, value, rnn_state = self.model(*agent_inputs, self.prev_rnn_state)
        dist_info = DistInfoStd(mean=mu, log_std=std)
        # action = self.distribution.sample(dist_info) if self._mode == 'sample' else mu
        dist = torch.distributions.normal.Normal(loc=mu, scale=std)
        action = dist.sample() if self._mode == 'sample' else mu
        if self.prev_rnn_state is None:
            prev_rnn_state = buffer_func(rnn_state, torch.zeros_like)
        else:
            prev_rnn_state = self.prev_rnn_state

        # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
        # (Special case: model should always leave B dimension in.)
        prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
        agent_info = AgentInfoRnn(dist_info=dist_info, value=value,
            prev_rnn_state=prev_rnn_state)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        self.advance_rnn_state(rnn_state)  # Keep on device.
        return AgentStep(action=action, agent_info=agent_info)


class MujocoVmpoAgent(MujocoMixin, VmpoAgent):
    pass


class AlternatingVmpoAgent(AlternatingRecurrentAgentMixin, MujocoVmpoAgent):
    pass

diff --git a/rlpyt/algos/pg/v_mpo.py b/rlpyt/algos/pg/v_mpo.py
new file mode 100644
index 00000000..66ebe2d5
--- /dev/null
+++ b/rlpyt/algos/pg/v_mpo.py
@@ -0,0 +1,245 @@
import torch
import numpy as np

from rlpyt.utils.tensor import batched_trace, batched_quadratic_form, valid_mean
from rlpyt.algos.base import RlAlgorithm
from rlpyt.models.popart_normalization import PopArtLayer
from rlpyt.agents.base import AgentInputs
from rlpyt.utils.quick_args import save__init__args
from rlpyt.utils.buffer import buffer_to
from rlpyt.utils.collections import namedarraytuple, namedtuple
from rlpyt.utils.misc import iterate_mb_idxs
from rlpyt.algos.utils import (discount_return, generalized_advantage_estimation, valid_from_done)

OptInfo = namedtuple("OptInfo", ["loss", "pi_loss", "eta_loss", "alpha_loss", "value_loss",
                                 "alpha_mu_loss", "alpha_sigma_loss",
                                 "mu_kl", "sigma_kl", "advantage", "normalized_return",
                                 "alpha", "eta", "alpha_mu", "alpha_sigma",
                                 "pi_mu", "pi_log_std", "policy_kl",
                                 "entropy", "perplexity"])


class VMPO(RlAlgorithm):
    """
    V-MPO: on-policy maximum a posteriori policy optimization (Song et al., 2019),
    with optional PopArt return normalization and support for discrete and
    continuous (diagonal Gaussian) action spaces.
    """

    opt_info_fields = tuple(f for f in OptInfo._fields)  # copy
    bootstrap_value = False  # The sampler need not provide Value(State'); the bootstrap value is taken from the agent's forward pass during optimization.

    def __init__(
            self,
            discount=0.99,
            learning_rate=1e-4,
            T_target_steps=100,
            bootstrap_with_online_model=False,
            OptimCls=torch.optim.Adam,
            pop_art_reward_normalization=True,
            optim_kwargs=None,
            initial_optim_state_dict=None,
            epochs=1,
            discrete_actions=False,
            epsilon_eta=0.01,
            epsilon_alpha=0.01,
            initial_eta=1.0,
            initial_alpha=5.0,
            initial_alpha_mu=1.0,
            initial_alpha_sigma=1.0,
            epsilon_alpha_mu=0.0075,
            epsilon_alpha_sigma=1e-5,
            ):
        """Saves input settings."""
        if optim_kwargs is None:
            optim_kwargs = dict()
        self.pop_art_normalizer = PopArtLayer()
        save__init__args(locals())

    def initialize(self, agent, n_itr, batch_spec, mid_batch_reset=False,
            examples=None, world_size=1, rank=0):
        self.agent = agent
        self.alpha = torch.autograd.Variable(torch.ones(1) * self.initial_alpha, requires_grad=True)
        self.alpha_mu = torch.autograd.Variable(torch.ones(1) * self.initial_alpha_mu, requires_grad=True)
        self.alpha_sigma = torch.autograd.Variable(torch.ones(1) * self.initial_alpha_sigma, requires_grad=True)
        self.eta = torch.autograd.Variable(torch.ones(1) * self.initial_eta, requires_grad=True)

        self.optimizer = self.OptimCls(list(self.agent.parameters()) +
                                       list(self.pop_art_normalizer.parameters()) +
                                       [self.alpha, self.alpha_mu, self.alpha_sigma, self.eta],
                                       lr=self.learning_rate, **self.optim_kwargs)
        if self.initial_optim_state_dict is not None:
            self.load_optim_state_dict(self.initial_optim_state_dict)
        self.n_itr = n_itr
        self.batch_spec = batch_spec
        self.mid_batch_reset = mid_batch_reset
        self.rank = rank
        self.world_size = world_size
        self.update_counter = 0
        self._batch_size = self.batch_spec.size // self.T_target_steps  # For logging.
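    # Summary of the update implemented by the methods below (terminology follows the
    # V-MPO paper, Song et al. 2019):
    #   * process_returns(): discounted returns and advantages; with PopArt enabled the
    #     PopArtLayer also acts as the (normalized) linear value head on top of the
    #     features returned by the model's value MLP (linear_value_output=False, as in
    #     examples/mujoco_vmpo.py).
    #   * loss(): keeps the top half of the valid advantages and forms the sample
    #     weights phi_i = exp(A_i / eta) / sum_j exp(A_j / eta); the total loss adds
    #     the weighted policy loss, the value loss, the temperature dual
    #     L_eta = eta * epsilon_eta + eta * log(mean_i exp(A_i / eta)), and Lagrange
    #     terms of the form alpha * (epsilon - KL.detach()) + alpha.detach() * KL.
    #   * continuous_actions_loss(): decouples the Gaussian KL into mean and covariance
    #     parts with separate multipliers alpha_mu and alpha_sigma.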
+ + def process_returns(self, reward, done, value_prediction, + action, dist_info, old_dist_info, opt_info): + done = done.type(reward.dtype) + if self.pop_art_reward_normalization: + unnormalized_value = value_prediction + value_prediction, normalized_value = self.pop_art_normalizer(value_prediction) + + bootstrap_value = value_prediction[-1] + reward, value_prediction, done = reward[:-1], value_prediction[:-1], done[:-1] + + return_ = discount_return(reward, done, bootstrap_value.detach(), self.discount) + if self.pop_art_reward_normalization: + self.pop_art_normalizer.update_parameters(return_.unsqueeze(-1), + torch.ones_like(return_.unsqueeze(-1))) + _, normalized_value = self.pop_art_normalizer(unnormalized_value[:-1]) + return_ = self.pop_art_normalizer.normalize(return_) + advantage = return_ - normalized_value.detach() + value_prediction = normalized_value + opt_info.normalized_return.append(return_.numpy()) + else: + advantage = return_ - value_prediction.detach() + + valid = valid_from_done(done) # Recurrent: no reset during training. + opt_info.advantage.append(advantage.numpy()) + + loss, opt_info = self.loss(dist_info=dist_info[:-1], + value=value_prediction, + action=action[:-1], + return_=return_, + advantage=advantage.detach(), + valid=valid, + old_dist_info=old_dist_info[:-1], + opt_info=opt_info) + return loss, opt_info + + def optimize_agent(self, itr, samples=None, sampler_itr=None): + """ + Train the agent, for multiple epochs over minibatches taken from the + input samples. Organizes agent inputs from the training data, and + moves them to device (e.g. GPU) up front, so that minibatches are + formed within device, without further data transfer. + """ + opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields)))) + agent_inputs = AgentInputs( # Move inputs to device once, index there. 
            observation=samples.env.observation,
            prev_action=samples.agent.prev_action,
            prev_reward=samples.env.prev_reward,
        )
        agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
        init_rnn_state = buffer_to(samples.agent.agent_info.prev_rnn_state[0], device=self.agent.device)
        T, B = samples.env.reward.shape[:2]
        mb_size = B // self.T_target_steps  # Each epoch runs T_target_steps minibatches over the B environments.
        for _ in range(self.epochs):
            for idxs in iterate_mb_idxs(B, mb_size, shuffle=True):
                self.optimizer.zero_grad()
                dist_info, value, _ = self.agent(*agent_inputs[:, idxs], init_rnn_state[idxs])
                loss, opt_info = self.process_returns(samples.env.reward[:, idxs],
                                                      done=samples.env.done[:, idxs],
                                                      value_prediction=value.cpu(),
                                                      action=samples.agent.action[:, idxs],
                                                      dist_info=dist_info,
                                                      old_dist_info=samples.agent.agent_info.dist_info[:, idxs],
                                                      opt_info=opt_info)
                loss.backward()
                self.optimizer.step()
                self.clamp_lagrange_multipliers()
                opt_info.loss.append(loss.item())
                self.update_counter += 1
        return opt_info

    def loss(self, dist_info, value, action, return_, advantage, valid, old_dist_info, opt_info):
        T, B = tuple(action.shape[:2])
        advantage = advantage.clamp_max(60)  # Clamp due to numerical instabilities in logsumexp.
        num_valid = valid.sum().type(torch.int32)
        # Map advantages to positive values (exp(A / 100)) so that multiplying by the valid mask
        # and taking topk selects the top 50% of *valid* advantages; log(.) * 100 recovers the
        # original advantage values afterwards.
        top_advantages, top_advantages_indices = torch.topk((valid * torch.exp(advantage / 100)).reshape(T * B),
                                                            num_valid // 2)
        top_advantages = torch.log(top_advantages) * 100
        advantage_mask = torch.zeros_like(advantage.view(T * B))
        advantage_mask[top_advantages_indices] = 1
        advantage_mask = advantage_mask.reshape(T, B)

        # Non-parametric sample weights over the selected advantages, with temperature eta.
        log_advantage_sum = torch.logsumexp(top_advantages / self.eta, dim=0)
        phi = torch.exp((advantage / self.eta) - log_advantage_sum)
        value_error = 0.5 * (value - return_) ** 2
        value_loss = valid_mean(value_error, valid)
        # Temperature dual: eta * epsilon_eta + eta * log(mean over the top half of exp(A / eta)).
        eta_loss = self.eta * self.epsilon_eta + self.eta * (log_advantage_sum - torch.log(0.5 * num_valid))

        if self.discrete_actions:
            pi_loss, alpha_loss, opt_info = self.discrete_actions_loss(advantage_mask, phi, action, dist_info,
                                                                       old_dist_info, opt_info)
            loss = pi_loss + value_loss + eta_loss + alpha_loss
        else:
            pi_loss, alpha_loss, opt_info = self.continuous_actions_loss(advantage_mask, phi, action, dist_info,
                                                                         old_dist_info, valid, opt_info)
            loss = pi_loss + value_loss + eta_loss + alpha_loss

        opt_info.pi_loss.append(pi_loss.item())
        opt_info.eta_loss.append(eta_loss.item())
        opt_info.value_loss.append(value_loss.item())
        opt_info.eta.append(self.eta.item())
        return loss, opt_info

    def discrete_actions_loss(self, advantage_mask, phi, action, dist_info, old_dist_info, opt_info):
        dist = self.agent.distribution
        pi_loss = - torch.sum(advantage_mask * (phi.detach() * dist.log_likelihood(action.contiguous(), dist_info)))
        policy_kl = dist.kl(old_dist_info, dist_info)
        alpha_loss = valid_mean(
            self.alpha * (self.epsilon_alpha - policy_kl.detach()) + self.alpha.detach() * policy_kl)
        opt_info.alpha_loss.append(alpha_loss.item())
        opt_info.alpha.append(self.alpha.item())
        opt_info.policy_kl.append(policy_kl.mean().item())
        return pi_loss, alpha_loss, opt_info

    def continuous_actions_loss(self, advantage_mask, phi, action, dist_info, old_dist_info, valid, opt_info):
        d = np.prod(action.shape[-1])  # Action dimensionality.
        # dist_info.log_std carries the (softplus-transformed) std, see VmpoAgent / MujocoVmpoFfModel.
        distribution = torch.distributions.normal.Normal(loc=dist_info.mean, scale=dist_info.log_std)
        pi_loss = - torch.sum(advantage_mask * (phi.detach() *
                                                distribution.log_prob(action).sum(dim=-1)))
        # pi_loss = - torch.sum(advantage_mask * (phi.detach() * self.agent.distribution.log_likelihood(action, dist_info)))
        # Decoupled Gaussian KL, treating the std vectors as the diagonal of the covariance:
        # mu_kl = KL(N(mu_old, S_old) || N(mu_new, S_old)) and
        # sigma_kl = KL(N(mu_old, S_old) || N(mu_old, S_new)).
        new_std = dist_info.log_std
        old_std = old_dist_info.log_std
        old_covariance = torch.diag_embed(old_std)
        old_covariance_inverse = torch.diag_embed(1 / old_std)
        new_covariance_inverse = torch.diag_embed(1 / new_std)
        old_covariance_determinant = torch.prod(old_std, dim=-1)
        new_covariance_determinant = torch.prod(new_std, dim=-1)

        mu_kl = 0.5 * batched_quadratic_form(dist_info.mean - old_dist_info.mean, old_covariance_inverse)
        trace = batched_trace(torch.matmul(new_covariance_inverse, old_covariance))
        sigma_kl = 0.5 * (trace - d + torch.log(new_covariance_determinant / old_covariance_determinant))
        alpha_mu_loss = valid_mean(
            self.alpha_mu * (self.epsilon_alpha_mu - mu_kl.detach()) + self.alpha_mu.detach() * mu_kl, valid)
        alpha_sigma_loss = valid_mean(self.alpha_sigma * (
            self.epsilon_alpha_sigma - sigma_kl.detach()) + self.alpha_sigma.detach() * sigma_kl, valid)
        opt_info.alpha_mu.append(self.alpha_mu.item())
        opt_info.alpha_sigma.append(self.alpha_sigma.item())
        opt_info.alpha_mu_loss.append(alpha_mu_loss.item())
        opt_info.mu_kl.append(valid_mean(mu_kl, valid).item())
        opt_info.sigma_kl.append(valid_mean(sigma_kl, valid).item())
        opt_info.alpha_sigma_loss.append(alpha_sigma_loss.item())
        opt_info.pi_mu.append(dist_info.mean.mean().item())
        opt_info.pi_log_std.append(dist_info.log_std.mean().item())
        return pi_loss, alpha_mu_loss + alpha_sigma_loss, opt_info

    def clamp_lagrange_multipliers(self):
        """
        As described in the paper, eta and the alphas are Lagrange multipliers that must
        remain positive, which is why they are clamped after every update.
        """
        with torch.no_grad():
            self.alpha.clamp_min_(1e-8)
            self.alpha_mu.clamp_min_(1e-8)
            self.alpha_sigma.clamp_min_(1e-8)
            self.eta.clamp_min_(1e-8)

    def optim_state_dict(self):
        return dict(
            optimizer=self.optimizer.state_dict(),
            eta=self.eta,
            alpha=self.alpha,
            alpha_mu=self.alpha_mu,
            alpha_sigma=self.alpha_sigma,
            pop_art_layer=self.pop_art_normalizer.state_dict()
        )

    def load_optim_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict["optimizer"])
        self.pop_art_normalizer.load_state_dict(state_dict["pop_art_layer"])
        self.eta.data = state_dict["eta"]
        self.alpha.data = state_dict["alpha"]
        self.alpha_mu.data = state_dict["alpha_mu"]
        self.alpha_sigma.data = state_dict["alpha_sigma"]

diff --git a/rlpyt/models/pg/mujoco_ff_model.py b/rlpyt/models/pg/mujoco_ff_model.py
index 710ef837..8ac09856 100644
--- a/rlpyt/models/pg/mujoco_ff_model.py
+++ b/rlpyt/models/pg/mujoco_ff_model.py
@@ -85,3 +85,57 @@ def forward(self, observation, prev_action, prev_reward):
     def update_obs_rms(self, observation):
         if self.normalize_observation:
             self.obs_rms.update(observation)


class MujocoVmpoFfModel(torch.nn.Module):
    """
    Feedforward model for Mujoco V-MPO agents: an MLP head which outputs the action
    distribution mean and std (the std passed through a softplus), and a separate
    MLP for the state-value estimate.
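    With linear_value_output=False (as in examples/mujoco_vmpo.py) the value MLP
    returns its last hidden layer (256 features) and the PopArt layer inside the
    V-MPO algorithm provides the final linear projection to a scalar value. A dummy
    rnn_state is returned by forward() so the model matches the recurrent interface
    expected by VmpoAgent.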
+ """ + + def __init__( + self, + observation_shape, + action_size, + linear_value_output=True, + layer_norm=False + ): + """Instantiate neural net modules according to inputs.""" + super().__init__() + self._obs_ndim = len(observation_shape) + input_size = int(np.prod(observation_shape)) + self.action_size = action_size + self.layer_norm = torch.nn.LayerNorm(input_size) if layer_norm else None + self.mu_mlp = MlpModel( + input_size=input_size, + hidden_sizes=[512, 256, 256], + output_size=2 * action_size, + ) + list(self.mu_mlp.parameters())[-1].data = list(self.mu_mlp.parameters())[-1].data / 100 + list(self.mu_mlp.parameters())[-2].data = list(self.mu_mlp.parameters())[-2].data / 100 + self.v = MlpModel( + input_size=input_size, + hidden_sizes=[512, 512, 256], + output_size=1 if linear_value_output else None, + ) + + def forward(self, observation, prev_action, prev_reward, init_rnn_state=None): + lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim) + assert not torch.any(torch.isnan(observation)), 'obs elem is nan' + + obs_flat = observation.reshape(T * B, -1) + if self.layer_norm: + obs_flat = torch.tanh(self.layer_norm(obs_flat)) + action = self.mu_mlp(obs_flat) + mu, std = (action[:, :self.action_size], action[:, self.action_size:]) + std = torch.log(1 + torch.exp(std)) # softplus + v = self.v(obs_flat).squeeze(-1) + + # Restore leading dimensions: [T,B], [B], or [], as input. + mu, std, v = restore_leading_dims((mu, std, v), lead_dim, T, B) + fake_rnn_state = torch.zeros(1, B, 1) + return mu, std, v, fake_rnn_state + diff --git a/rlpyt/models/popart_normalization.py b/rlpyt/models/popart_normalization.py new file mode 100644 index 00000000..aa2677c3 --- /dev/null +++ b/rlpyt/models/popart_normalization.py @@ -0,0 +1,105 @@ +""" +copy from https://github.com/aluscher/torchbeastpopart/blob/master/torchbeast/core/popart.py + +added support for multi-task learning +""" +import math +import torch + + +class PopArtLayer(torch.nn.Module): + + def __init__(self, input_features=256, output_features=1, beta=1e-4): + self.beta = beta + + super(PopArtLayer, self).__init__() + + self.input_features = input_features + self.output_features = output_features + + self.weight = torch.nn.Parameter(torch.Tensor(output_features, input_features)) + self.bias = torch.nn.Parameter(torch.Tensor(output_features)) + + self.register_buffer('mu', torch.zeros(output_features, requires_grad=False)) + self.register_buffer('sigma', torch.ones(output_features, requires_grad=False)) + + self.reset_parameters() + + def reset_parameters(self): + torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + torch.nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, inputs, task=None): + if len(inputs.shape) == 2: + inputs = inputs.unsqueeze(-1) + input_shape = inputs.shape + inputs = inputs.reshape(-1, self.input_features) + + normalized_output = inputs.mm(self.weight.t()) + normalized_output += self.bias.unsqueeze(0).expand_as(normalized_output) + normalized_output = normalized_output.reshape(*input_shape[:-1], self.output_features) + + with torch.no_grad(): + output = normalized_output * self.sigma + self.mu + + if task is not None: + output = output.gather(-1, task.unsqueeze(-1)) + normalized_output = normalized_output.gather(-1, task.unsqueeze(-1)) + + return [output.squeeze(-1), normalized_output.squeeze(-1)] + + @torch.no_grad() + def normalize(self, inputs, 
task=None): + """ + task: task ids + """ + task = torch.zeros(inputs.shape, dtype=torch.int64) if task is None else task + input_device = inputs.device + inputs = inputs.to(self.mu.device) + mu = self.mu.expand(*inputs.shape, self.output_features).gather(-1, task.unsqueeze(-1)).squeeze(-1) + sigma = self.sigma.expand(*inputs.shape, self.output_features).gather(-1, task.unsqueeze(-1)).squeeze(-1) + output = (inputs - mu) / sigma + return output.to(input_device) + + @torch.no_grad() + def update_parameters(self, vs, task): + """ + task: one hot vector of tasks + """ + vs, task = vs.to(self.mu.device), task.to(self.mu.device) + + oldmu = self.mu + oldsigma = self.sigma + + vs = vs * task + n = task.sum((0, 1)) + mu = vs.sum((0, 1)) / n + nu = torch.sum(vs ** 2, (0, 1)) / n + sigma = torch.sqrt(nu - mu ** 2) + sigma = torch.clamp(sigma, min=1e-2, max=1e+6) + + mu[torch.isnan(mu)] = self.mu[torch.isnan(mu)] + sigma[torch.isnan(sigma)] = self.sigma[torch.isnan(sigma)] + + self.mu = (1 - self.beta) * self.mu + self.beta * mu + self.sigma = (1 - self.beta) * self.sigma + self.beta * sigma + # print(f'new sigma: {self.sigma}#################################################3') + + self.weight.data = (self.weight.t() * oldsigma / self.sigma).t() + self.bias.data = (oldsigma * self.bias + oldmu - self.mu) / self.sigma + + def state_dict(self): + return dict(mu=self.mu, + sigma=self.sigma, + weight=self.weight.data, + bias=self.bias.data) + + def load_state_dict(self, state_dict): + with torch.no_grad(): + self.mu = state_dict['mu'] + self.sigma = state_dict['sigma'] + self.weight.data = state_dict['weight'] + self.bias.data = state_dict['bias'] diff --git a/rlpyt/utils/tensor.py b/rlpyt/utils/tensor.py index dfde1607..bfd5a3a4 100644 --- a/rlpyt/utils/tensor.py +++ b/rlpyt/utils/tensor.py @@ -84,3 +84,18 @@ def restore_leading_dims(tensors, lead_dim, T=1, B=1): assert B == 1 tensors = tuple(t.squeeze(0) for t in tensors) return tensors if is_seq else tensors[0] + + +def batched_quadratic_form(vector, matrix): + assert vector.shape[-1] == matrix.shape[-2] == matrix.shape[-1], 'received invalid shapes ' + str( + vector.shape) + str(matrix.shape) + + result = torch.matmul( + torch.matmul(vector.unsqueeze(-2), matrix) + , vector.unsqueeze(-1)) + return result.squeeze(-1).squeeze(-1) + + +def batched_trace(matrix): + trace = torch.sum(torch.diagonal(matrix, dim1=-2, dim2=-1), dim=-1) + return trace
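
The two helpers added to rlpyt/utils/tensor.py implement the decoupled Gaussian KL terms
used in continuous_actions_loss(). Below is a minimal sketch (not part of the patch; it
assumes rlpyt with this patch applied is importable) that checks them against
torch.distributions, treating the per-dimension std vectors as covariance diagonals
exactly as the loss code does:

import torch
from torch.distributions import MultivariateNormal, kl_divergence
from rlpyt.utils.tensor import batched_quadratic_form, batched_trace

B, d = 5, 3
mu_old, mu_new = torch.randn(B, d), torch.randn(B, d)
var_old, var_new = torch.rand(B, d) + 0.1, torch.rand(B, d) + 0.1  # diagonal covariance entries

# mu_kl = KL(N(mu_old, S_old) || N(mu_new, S_old)); sigma_kl = KL(N(mu_old, S_old) || N(mu_old, S_new)).
mu_kl = 0.5 * batched_quadratic_form(mu_new - mu_old, torch.diag_embed(1 / var_old))
trace = batched_trace(torch.matmul(torch.diag_embed(1 / var_new), torch.diag_embed(var_old)))
sigma_kl = 0.5 * (trace - d + torch.log(var_new.prod(-1) / var_old.prod(-1)))

ref_mu = kl_divergence(MultivariateNormal(mu_old, torch.diag_embed(var_old)),
                       MultivariateNormal(mu_new, torch.diag_embed(var_old)))
ref_sigma = kl_divergence(MultivariateNormal(mu_old, torch.diag_embed(var_old)),
                          MultivariateNormal(mu_old, torch.diag_embed(var_new)))
assert torch.allclose(mu_kl, ref_mu, atol=1e-4) and torch.allclose(sigma_kl, ref_sigma, atol=1e-4)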