Add V-MPO #194

Open · wants to merge 1 commit into base: master
58 changes: 58 additions & 0 deletions examples/mujoco_vmpo.py
@@ -0,0 +1,58 @@

"""
Runs one instance of a MuJoCo environment and optimizes the policy using the V-MPO algorithm.
"""

from rlpyt.samplers.serial.sampler import SerialSampler
from rlpyt.runners.minibatch_rl import MinibatchRlEval
from rlpyt.utils.logging.context import logger_context
from rlpyt.algos.pg.v_mpo import VMPO
from rlpyt.agents.pg.gaussian_vmpo_agent import MujocoVmpoAgent
from rlpyt.models.pg.mujoco_ff_model import MujocoVmpoFfModel
from rlpyt.envs.gym import make as gym_make
from rlpyt.samplers.parallel.cpu.sampler import CpuSampler
from rlpyt.utils.launching.affinity import make_affinity


def build_and_train(id="Ant-v3", run_ID=0, cuda_idx=None):
# cuda_idx is accepted for CLI compatibility but unused here: the example samples and trains on CPU only (n_gpu=0).
affinity = make_affinity(n_cpu_core=24, cpu_per_run=24, n_gpu=0, set_affinity=True)
sampler = CpuSampler(
EnvCls=gym_make,
env_kwargs=dict(id=id),
eval_env_kwargs=dict(id=id),
batch_T=40, # Forty time-steps per sampler iteration.
batch_B=64 * 100,
max_decorrelation_steps=100,
eval_n_envs=1,
eval_max_steps=int(10e8),
eval_max_trajectories=8,
)
algo = VMPO(T_target_steps=100, pop_art_reward_normalization=True, discrete_actions=False, epochs=1)
agent = MujocoVmpoAgent(ModelCls=MujocoVmpoFfModel, model_kwargs=dict(linear_value_output=False, layer_norm=True))
runner = MinibatchRlEval(
algo=algo,
agent=agent,
sampler=sampler,
n_steps=int(1e10),
log_interval_steps=int(1e6),
affinity=affinity
)
config = dict(id=id)
name = "vmpo_" + id
log_dir = "vmpo_mujoco"
with logger_context(log_dir, run_ID, name, config, snapshot_mode="last"):
runner.train()


if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--id', help='gym env id', default='Ant-v3')
parser.add_argument('--run_ID', help='run identifier (logging)', type=int, default=0)
parser.add_argument('--cuda_idx', help='gpu to use', type=int, default=None)
args = parser.parse_args()
build_and_train(
id=args.id,
run_ID=args.run_ID,
cuda_idx=args.cuda_idx,
)
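The example above assumes a 24-core machine and a very large sampler batch (batch_B = 64 * 100). As a rough, non-authoritative sketch of how the same pipeline might be scaled down for a smaller machine (the HalfCheetah-v3 id and all numeric values below are illustrative, not part of this PR):

from rlpyt.samplers.parallel.cpu.sampler import CpuSampler
from rlpyt.envs.gym import make as gym_make
from rlpyt.utils.launching.affinity import make_affinity
from rlpyt.algos.pg.v_mpo import VMPO
from rlpyt.agents.pg.gaussian_vmpo_agent import MujocoVmpoAgent
from rlpyt.models.pg.mujoco_ff_model import MujocoVmpoFfModel

# Only the class names come from this PR; every numeric value here is illustrative.
affinity = make_affinity(n_cpu_core=4, cpu_per_run=4, n_gpu=0, set_affinity=True)
sampler = CpuSampler(
    EnvCls=gym_make,
    env_kwargs=dict(id="HalfCheetah-v3"),
    eval_env_kwargs=dict(id="HalfCheetah-v3"),
    batch_T=40,        # time-steps per environment per sampler iteration
    batch_B=64,        # far fewer parallel environments than the 64 * 100 above
    max_decorrelation_steps=100,
    eval_n_envs=1,
    eval_max_steps=int(1e5),
    eval_max_trajectories=4,
)
algo = VMPO(T_target_steps=8, pop_art_reward_normalization=True, discrete_actions=False, epochs=1)
agent = MujocoVmpoAgent(ModelCls=MujocoVmpoFfModel, model_kwargs=dict(linear_value_output=False, layer_norm=True))
# The MinibatchRlEval runner and logger_context calls are then the same as in the example above.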
67 changes: 67 additions & 0 deletions rlpyt/agents/pg/gaussian_vmpo_agent.py
@@ -0,0 +1,67 @@
from rlpyt.distributions.gaussian import Gaussian
import torch
from rlpyt.agents.pg.base import AgentInfo, AgentInfoRnn
from rlpyt.utils.buffer import buffer_to, buffer_func, buffer_method
from rlpyt.agents.pg.mujoco import MujocoMixin, AlternatingRecurrentGaussianPgAgent
from rlpyt.agents.base import (AgentStep, BaseAgent, RecurrentAgentMixin,
AlternatingRecurrentAgentMixin)
from rlpyt.utils.collections import namedarraytuple
from rlpyt.distributions.gaussian import DistInfoStd

DistInfo = namedarraytuple("DistInfo", ["mean", "std"])


class VmpoAgent(RecurrentAgentMixin, BaseAgent):
"""
Base class for Gaussian V-MPO agents. This version uses a Gaussian with a diagonal covariance matrix and expects a
mu and std vector from the model. The model should already have applied a softplus to the std output to ensure
positive values; using exp on the std output seems to work slightly worse.
"""

def initialize(self, env_spaces, *args, **kwargs):
"""Extends base method to build Gaussian distribution."""
super().initialize(env_spaces, *args, **kwargs)
self.distribution = Gaussian(
dim=env_spaces.action.shape[0],
# squash=env_spaces.action.high[0],
# min_std=MIN_STD,
# clip=env_spaces.action.high[0], # Probably +1?
)

def __call__(self, observation, prev_action, prev_reward, init_rnn_state):
"""Performs forward pass on training data, for algorithm."""
model_inputs = buffer_to((observation, prev_action, prev_reward, init_rnn_state),
device=self.device)
mu, std, value, rnn_state = self.model(*model_inputs)
dist_info, value = buffer_to((DistInfoStd(mean=mu, log_std=std), value), device="cpu")  # NOTE: the log_std field carries the std directly here.
return dist_info, value, rnn_state

@torch.no_grad()
def step(self, observation, prev_action, prev_reward):
agent_inputs = buffer_to((observation, prev_action, prev_reward), device=self.device)
mu, std, value, rnn_state = self.model(*agent_inputs, self.prev_rnn_state)
dist_info = DistInfoStd(mean=mu, log_std=std)
# action = self.distribution.sample(dist_info) if self._mode == 'sample' else mu
dist = torch.distributions.normal.Normal(loc=mu, scale=std)
action = dist.sample() if self._mode == 'sample' else mu
if self.prev_rnn_state is None:
prev_rnn_state = buffer_func(rnn_state, torch.zeros_like)
else:
prev_rnn_state = self.prev_rnn_state

# Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
# (Special case: model should always leave B dimension in.)
prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
agent_info = AgentInfoRnn(dist_info=dist_info, value=value,
prev_rnn_state=prev_rnn_state)
action, agent_info = buffer_to((action, agent_info), device="cpu")
self.advance_rnn_state(rnn_state) # Keep on device.
return AgentStep(action=action, agent_info=agent_info)


class MujocoVmpoAgent(MujocoMixin, VmpoAgent):
pass


class AlternatingVmpoAgent(AlternatingRecurrentAgentMixin, MujocoVmpoAgent):
pass
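The MujocoVmpoFfModel used in the example is not included in this diff. As a minimal, illustrative sketch only (the class below is invented for illustration and is not the model from this PR), VmpoAgent expects its model to return a mean, a strictly positive std, a value estimate, and an rnn_state; real rlpyt models additionally handle leading time/batch dimensions, which are omitted here:

import torch.nn as nn
import torch.nn.functional as F


class TinyVmpoModelSketch(nn.Module):
    """Illustrative stand-in showing the (mu, std, value, rnn_state) interface that
    VmpoAgent.__call__ and VmpoAgent.step expect; NOT the MujocoVmpoFfModel from this PR."""

    def __init__(self, obs_dim, action_dim, hidden_size=64):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_dim, hidden_size), nn.Tanh())
        self.mu_head = nn.Linear(hidden_size, action_dim)
        self.std_head = nn.Linear(hidden_size, action_dim)
        self.value_head = nn.Linear(hidden_size, 1)

    def forward(self, observation, prev_action, prev_reward, rnn_state=None):
        h = self.body(observation.float())
        mu = self.mu_head(h)
        std = F.softplus(self.std_head(h))    # keep the scale strictly positive
        value = self.value_head(h).squeeze(-1)
        return mu, std, value, rnn_state      # feed-forward: rnn_state is passed through unchanged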
245 changes: 245 additions & 0 deletions rlpyt/algos/pg/v_mpo.py
@@ -0,0 +1,245 @@
import torch
import numpy as np
from rlpyt.utils.tensor import batched_trace, batched_quadratic_form
from rlpyt.algos.base import RlAlgorithm
from rlpyt.models.popart_normalization import PopArtLayer
from rlpyt.agents.base import AgentInputs
from rlpyt.utils.tensor import valid_mean
from rlpyt.utils.quick_args import save__init__args
from rlpyt.utils.buffer import buffer_to
from rlpyt.utils.collections import namedarraytuple, namedtuple
from rlpyt.utils.misc import iterate_mb_idxs
from rlpyt.algos.utils import (discount_return, generalized_advantage_estimation, valid_from_done)

OptInfo = namedtuple("OptInfo", ["loss", "pi_loss", "eta_loss", "alpha_loss", "value_loss",
"alpha_mu_loss", "alpha_sigma_loss",
"mu_kl", "sigma_kl", "advantage", "normalized_return",
"alpha", "eta", "alpha_mu", "alpha_sigma",
"pi_mu", "pi_log_std", "policy_kl",
"entropy", "perplexity"])

class VMPO(RlAlgorithm):
opt_info_fields = tuple(f for f in OptInfo._fields) # copy
bootstrap_value = False  # Sampler need not provide Value(State'); the bootstrap value comes from the training-time forward pass.

def __init__(
self,
discount=0.99,
learning_rate=1e-4,
T_target_steps=100,
bootstrap_with_online_model=False,
OptimCls=torch.optim.Adam,
pop_art_reward_normalization=True,
optim_kwargs=None,
initial_optim_state_dict=None,
epochs=1,
discrete_actions=False,
epsilon_eta=0.01,
epsilon_alpha=0.01,
initial_eta=1.0,
initial_alpha=5.0,
initial_alpha_mu=1.0,
initial_alpha_sigma=1.0,
epsilon_alpha_mu=0.0075,
epsilon_alpha_sigma=1e-5,
):
"""Saves input settings."""
if optim_kwargs is None:
optim_kwargs = dict()
self.pop_art_normalizer = PopArtLayer()
save__init__args(locals())

def initialize(self, agent, n_itr, batch_spec, mid_batch_reset=False,
examples=None, world_size=1, rank=0):
self.agent = agent
self.alpha = torch.tensor([self.initial_alpha], dtype=torch.float32, requires_grad=True)
self.alpha_mu = torch.tensor([self.initial_alpha_mu], dtype=torch.float32, requires_grad=True)
self.alpha_sigma = torch.tensor([self.initial_alpha_sigma], dtype=torch.float32, requires_grad=True)
self.eta = torch.tensor([self.initial_eta], dtype=torch.float32, requires_grad=True)

self.optimizer = self.OptimCls(list(self.agent.parameters()) +
list(self.pop_art_normalizer.parameters()) +
[self.alpha, self.alpha_mu, self.alpha_sigma, self.eta],
lr=self.learning_rate, **self.optim_kwargs)
if self.initial_optim_state_dict is not None:
self.load_optim_state_dict(self.initial_optim_state_dict)
self.n_itr = n_itr
self.batch_spec = batch_spec
self.mid_batch_reset = mid_batch_reset
self.rank = rank
self.world_size = world_size
self._batch_size = self.batch_spec.size // self.T_target_steps # For logging.

def process_returns(self, reward, done, value_prediction,
action, dist_info, old_dist_info, opt_info):
done = done.type(reward.dtype)
if self.pop_art_reward_normalization:
unnormalized_value = value_prediction
value_prediction, normalized_value = self.pop_art_normalizer(value_prediction)

bootstrap_value = value_prediction[-1]
reward, value_prediction, done = reward[:-1], value_prediction[:-1], done[:-1]

return_ = discount_return(reward, done, bootstrap_value.detach(), self.discount)
if self.pop_art_reward_normalization:
self.pop_art_normalizer.update_parameters(return_.unsqueeze(-1),
torch.ones_like(return_.unsqueeze(-1)))
_, normalized_value = self.pop_art_normalizer(unnormalized_value[:-1])
return_ = self.pop_art_normalizer.normalize(return_)
advantage = return_ - normalized_value.detach()
value_prediction = normalized_value
opt_info.normalized_return.append(return_.numpy())
else:
advantage = return_ - value_prediction.detach()

valid = valid_from_done(done) # Recurrent: no reset during training.
opt_info.advantage.append(advantage.numpy())

loss, opt_info = self.loss(dist_info=dist_info[:-1],
value=value_prediction,
action=action[:-1],
return_=return_,
advantage=advantage.detach(),
valid=valid,
old_dist_info=old_dist_info[:-1],
opt_info=opt_info)
return loss, opt_info

def optimize_agent(self, itr, samples=None, sampler_itr=None):
"""
Train the agent for multiple epochs over minibatches taken from the
input samples. Organizes agent inputs from the training data and
moves them to device (e.g. GPU) up front, so that minibatches are
formed on device without further data transfer.
"""
opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
agent_inputs = AgentInputs( # Move inputs to device once, index there.
observation=samples.env.observation,
prev_action=samples.agent.prev_action,
prev_reward=samples.env.prev_reward,
)
agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
init_rnn_state = buffer_to(samples.agent.agent_info.prev_rnn_state[0], device=self.agent.device)
T, B = samples.env.reward.shape[:2]
mb_size = B // self.T_target_steps  # Yields T_target_steps minibatch updates per epoch against the sampled (old) policy.
for _ in range(self.epochs):
for idxs in iterate_mb_idxs(B, mb_size, shuffle=True):
self.optimizer.zero_grad()
dist_info, value, _ = self.agent(*agent_inputs[:, idxs], init_rnn_state[idxs])
loss, opt_info = self.process_returns(samples.env.reward[:, idxs],
done=samples.env.done[:, idxs],
value_prediction=value.cpu(),
action=samples.agent.action[:, idxs],
dist_info=dist_info,
old_dist_info=samples.agent.agent_info.dist_info[:, idxs],
opt_info=opt_info)
loss.backward()
self.optimizer.step()
self.clamp_lagrange_multipliers()
opt_info.loss.append(loss.item())
self.update_counter += 1
return opt_info

def loss(self, dist_info, value, action, return_, advantage, valid, old_dist_info, opt_info):
T, B = tuple(action.shape[:2])
advantage = advantage.clamp_max(60)  # Clamp to avoid numerical instabilities in logsumexp.
num_valid = valid.sum().type(torch.int32)
# Map advantages to positive values so that topk on the valid-masked tensor returns the top 50% of valid advantages.
top_advantages, top_advantage_indices = torch.topk((valid * torch.exp(advantage / 100)).reshape(T * B),
num_valid // 2)
top_advantages = torch.log(top_advantages) * 100
advantage_mask = torch.zeros_like(advantage.view(T * B))
advantage_mask[top_advantage_indices] = 1
advantage_mask = advantage_mask.reshape(T, B)

log_advantage_sum = torch.logsumexp(top_advantages / self.eta, dim=0)
phi = torch.exp((advantage/self.eta) - log_advantage_sum)
value_error = 0.5 * (value - return_) ** 2
value_loss = valid_mean(value_error, valid)
eta_loss = self.eta * self.epsilon_eta + self.eta * (log_advantage_sum - torch.log(0.5 * num_valid))

if self.discrete_actions:
pi_loss, alpha_loss, opt_info = self.discrete_actions_loss(advantage_mask, phi, action, dist_info,
old_dist_info, opt_info)
loss = pi_loss + value_loss + eta_loss + alpha_loss
else:
pi_loss, alpha_loss, opt_info = self.continuous_actions_loss(advantage_mask, phi, action, dist_info,
old_dist_info, valid, opt_info)
loss = pi_loss + value_loss + eta_loss + alpha_loss

opt_info.pi_loss.append(pi_loss.item())
opt_info.eta_loss.append(eta_loss.item())
opt_info.value_loss.append(value_loss.item())
opt_info.eta.append(self.eta.item())
return loss, opt_info

def discrete_actions_loss(self, advantage_mask, phi, action, dist_info, old_dist_info, opt_info):
dist = self.agent.distribution
pi_loss = - torch.sum(advantage_mask * (phi.detach() * dist.log_likelihood(action.contiguous(), dist_info)))
policy_kl = dist.kl(old_dist_info, dist_info)
alpha_loss = valid_mean(
self.alpha * (self.epsilon_alpha - policy_kl.detach()) + self.alpha.detach() * policy_kl)
opt_info.alpha_loss.append(alpha_loss.item())
opt_info.alpha.append(self.alpha.item())
opt_info.policy_kl.append(policy_kl.mean().item())
return pi_loss, alpha_loss, opt_info

def continuous_actions_loss(self, advantage_mask, phi, action, dist_info, old_dist_info, valid, opt_info):
d = np.prod(action.shape[-1])
distribution = torch.distributions.normal.Normal(loc=dist_info.mean, scale=dist_info.log_std)  # log_std field holds the std directly.
pi_loss = - torch.sum(advantage_mask * (phi.detach() * distribution.log_prob(action).sum(dim=-1)))
# pi_loss = - torch.sum(advantage_mask * (phi.detach() * self.agent.distribution.log_likelihood(action, dist_info)))
# The diagonal covariance terms below are built from the std entries as output by the model.
new_std = dist_info.log_std
old_std = old_dist_info.log_std
old_covariance = torch.diag_embed(old_std)
old_covariance_inverse = torch.diag_embed(1 / old_std)
new_covariance_inverse = torch.diag_embed(1 / new_std)
old_covariance_determinant = torch.prod(old_std, dim=-1)
new_covariance_determinant = torch.prod(new_std, dim=-1)

mu_kl = 0.5 * batched_quadratic_form(dist_info.mean - old_dist_info.mean, old_covariance_inverse)
trace = batched_trace(torch.matmul(new_covariance_inverse, old_covariance))
sigma_kl = 0.5 * (trace - d + torch.log(new_covariance_determinant / old_covariance_determinant))
alpha_mu_loss = valid_mean(
self.alpha_mu * (self.epsilon_alpha_mu - mu_kl.detach()) + self.alpha_mu.detach() * mu_kl, valid)
alpha_sigma_loss = valid_mean(self.alpha_sigma * (
self.epsilon_alpha_sigma - sigma_kl.detach()) + self.alpha_sigma.detach() * sigma_kl, valid)
opt_info.alpha_mu.append(self.alpha_mu.item())
opt_info.alpha_sigma.append(self.alpha_sigma.item())
opt_info.alpha_mu_loss.append(alpha_mu_loss.item())
opt_info.mu_kl.append(valid_mean(mu_kl, valid).item())
opt_info.sigma_kl.append(valid_mean(sigma_kl, valid).item())
opt_info.alpha_sigma_loss.append(alpha_sigma_loss.item())
opt_info.pi_mu.append(dist_info.mean.mean().item())
opt_info.pi_log_std.append(dist_info.log_std.mean().item())
return pi_loss, alpha_mu_loss + alpha_sigma_loss, opt_info

def clamp_lagrange_multipliers(self):
"""
As described in the paper, alpha and eta are Lagrange multipliers that must remain positive, which is why
they are clamped after every update.
"""
with torch.no_grad():
self.alpha.clamp_min_(1e-8)
self.alpha_mu.clamp_min_(1e-8)
self.alpha_sigma.clamp_min_(1e-8)
self.eta.clamp_min_(1e-8)

def optim_state_dict(self):
return dict(
optimizer=self.optimizer.state_dict(),
eta=self.eta,
alpha=self.alpha,
alpha_mu=self.alpha_mu,
alpha_sigma=self.alpha_sigma,
pop_art_layer=self.pop_art_normalizer.state_dict()
)

def load_optim_state_dict(self, state_dict):
self.optimizer.load_state_dict(state_dict["optimizer"])
self.pop_art_normalizer.load_state_dict(state_dict['pop_art_layer'])
self.eta.data = state_dict['eta']
self.alpha.data = state_dict['alpha']
self.alpha_mu.data = state_dict['alpha_mu']
self.alpha_sigma.data = state_dict['alpha_sigma']
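As a standalone, hedged illustration of the non-parametric E-step that loss() implements (the function and variable names below are invented; the valid mask, the [T, B] reshaping, and the exp/log round-trip used above for masking are omitted):

import torch


def vmpo_weights_sketch(advantages, eta, epsilon_eta=0.01):
    """Flat-batch sketch of the non-parametric E-step in VMPO.loss(): keep the top half of
    the advantages, weight those samples by a softmax over advantage / eta, and form the
    dual (temperature) loss for eta."""
    n = advantages.numel()
    top_adv, top_idx = torch.topk(advantages, n // 2)      # top-half selection
    log_z = torch.logsumexp(top_adv / eta, dim=0)          # log-partition over the top half
    weights = torch.exp(top_adv / eta - log_z)             # softmax weights over the top half (sum to 1)
    psi = torch.zeros_like(advantages).scatter(0, top_idx, weights)
    eta_loss = eta * epsilon_eta + eta * (log_z - torch.log(torch.tensor(n / 2.0)))
    return psi, eta_loss


# The policy loss then uses psi.detach() to weight the action log-probabilities, as in continuous_actions_loss().
advantages = torch.randn(8)
eta = torch.tensor(1.0, requires_grad=True)
psi, eta_loss = vmpo_weights_sketch(advantages, eta)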