update ppo and a2c
zhangchuheng123 committed Sep 5, 2018
1 parent bbda275 commit 9224a70
Showing 3 changed files with 792 additions and 15 deletions.
108 changes: 93 additions & 15 deletions TRPO.py
@@ -1,31 +1,72 @@
"""
Implementation of TRPO
ref: Schulman, John, et al. "Trust region policy optimization." International Conference on Machine Learning. 2015.
NOTICE:
`Tensor2` means 2D-Tensor (num_samples, num_dims)
"""

import gym
import torch
import torch.nn as nn
import torch.optim as opt
from torch import Tensor
from torch.autograd import Variable
from collections import namedtuple
from itertools import count
import scipy.optimize as sciopt
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
from os.path import join as joindir
import pandas as pd
import numpy as np
import argparse
import datetime
import math


Transition = namedtuple('Transition', ('state', 'action', 'mask', 'next_state', 'reward'))
OldCopy = namedtuple('OldCopy', ('log_density', 'action_mean', 'action_log_std', 'action_std'))
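# Note (added): OldCopy presumably caches the pre-update policy's outputs so the
# surrogate loss and the KL constraint can be evaluated against the old policy.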
EPS = 1e-10
RESULT_DIR = '../result'


class args(object):
-    env_name = 'Reacher-v2'
+    env_name = 'Hopper-v2'
    seed = 1234
-    num_episode = 100
-    batch_size = 1000
-    max_step_per_episode = 100
+    num_episode = 1000
+    batch_size = 5000
+    max_step_per_episode = 200
    gamma = 0.995
-    tau = 0.97
+    lamda = 0.97
    l2_reg = 1e-3
    value_opt_max_iter = 25
    damping = 0.1
    max_kl = 1e-2
    cg_nsteps = 10
    log_num_episode = 1
    num_parallel_run = 5

def add_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='Hopper-v2')
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--num_episode', type=int, default=1000)
    parser.add_argument('--batch_size', type=int, default=5000)
    parser.add_argument('--max_step_per_episode', type=int, default=200)
    parser.add_argument('--gamma', type=float, default=0.995)
    parser.add_argument('--lamda', type=float, default=0.97)
    parser.add_argument('--l2_reg', type=float, default=1e-3)
    parser.add_argument('--value_opt_max_iter', type=int, default=25)
    parser.add_argument('--damping', type=float, default=0.1)
    parser.add_argument('--max_kl', type=float, default=1e-2)
    parser.add_argument('--cg_nsteps', type=int, default=10)
    parser.add_argument('--log_num_episode', type=int, default=1)
    parser.add_argument('--num_parallel_run', type=int, default=5)

    args = parser.parse_args()
    return args
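
# Example invocation (hypothetical, using the defaults above):
#     python TRPO.py --env_name Hopper-v2 --num_episode 1000 --batch_size 5000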

class Policy(nn.Module):
    def __init__(self, num_inputs, num_outputs):
@@ -411,8 +452,8 @@ def line_search(policy_net, get_loss, full_step, grad, max_num_backtrack=10, acc
            return True, xnew
        alpha *= 0.5
    return False, x0
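# Aside (assumed reconstruction, not part of this commit): the tail above is
# consistent with a standard backtracking line search that halves the step until
# the surrogate loss improves enough; a minimal sketch, with expected_improve
# and accept_ratio as assumed names:
def _line_search_sketch(get_loss, x0, full_step, expected_improve, max_num_backtrack=10, accept_ratio=0.1):
    fval = get_loss(x0)
    alpha = 1.0
    for _ in range(max_num_backtrack):
        xnew = x0 + alpha * full_step
        newfval = get_loss(xnew)
        # accept the step if the actual improvement is a large enough
        # fraction of the linearly predicted improvement
        if (fval - newfval) / (alpha * expected_improve) > accept_ratio:
            return True, xnew
        alpha *= 0.5
    return False, x0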

-if __name__ == '__main__':
+def trpo(args):
    env = gym.make(args.env_name)
    num_inputs = env.observation_space.shape[0]
    num_actions = env.action_space.shape[0]
@@ -427,10 +468,12 @@ def line_search(policy_net, get_loss, full_step, grad, max_num_backtrack=10, acc
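    # Note (added): ZFilter is presumably a running mean/std normalizer; here it
    # rescales rewards without de-meaning and clips the result to [-10, 10].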
    running_reward = ZFilter((1,), demean=False, clip=10)

    reward_record = []
+    global_steps = 0

    for i_episode in range(args.num_episode):
        memory = Memory()

-        # sample data
+        # sample data: single path method
        num_steps = 0
        while num_steps < args.batch_size:
            state = env.reset()
@@ -453,8 +496,10 @@ def line_search(policy_net, get_loss, full_step, grad, max_num_backtrack=10, acc

                state = next_state

-            reward_record.append(reward_sum)
            num_steps += (t + 1)
+            global_steps += (t + 1)
+            reward_record.append({'steps': global_steps, 'reward': reward_sum})

        batch = memory.sample()
        batch_size = len(memory)

@@ -475,13 +520,14 @@ def line_search(policy_net, get_loss, full_step, grad, max_num_backtrack=10, acc
        for i in reversed(range(batch_size)):
            returns[i] = rewards[i] + args.gamma * prev_return * masks[i]
            deltas[i] = rewards[i] + args.gamma * prev_value * masks[i] - values[i]
-            # ref: https://arxiv.org/pdf/1506.02438.pdf
-            advantages[i] = deltas[i] + args.gamma * args.tau * prev_advantage * masks[i]
+            # ref: https://arxiv.org/pdf/1506.02438.pdf (generalized advantage estimation)
+            # notation follows the PPO paper
+            advantages[i] = deltas[i] + args.gamma * args.lamda * prev_advantage * masks[i]

            prev_return = returns[i]
            prev_value = values[i]
            prev_advantage = advantages[i]
-        advantages = (advantages - advantages.mean()) / advantages.std()
+        advantages = (advantages - advantages.mean()) / (advantages.std() + EPS)
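        # Note (added): the recursion above implements generalized advantage
        # estimation (GAE) from the paper referenced above:
        #     delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        #     A_t = sum_{l>=0} (gamma * lamda)^l * delta_{t+l}
        #         = delta_t + gamma * lamda * A_{t+1}
        # with masks zeroing the bootstrap term at episode boundaries.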

        # optimize value network
        loss_func_args = (value_net, states, returns)
@@ -522,6 +568,38 @@ def line_search(policy_net, get_loss, full_step, grad, max_num_backtrack=10, acc
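        # Aside (assumed reconstruction, not part of this commit): the collapsed
        # lines here typically obtain the step direction by conjugate gradient on
        # the damped Fisher-vector product fvp(v) = F v + args.damping * v.
        # A minimal sketch, all names assumed:
        def _conjugate_gradient_sketch(fvp, b, nsteps, residual_tol=1e-10):
            # iteratively solve F x = b, where fvp(v) returns the damped product F v
            x = torch.zeros_like(b)
            r = b.clone()
            p = b.clone()
            rdotr = torch.dot(r, r)
            for _ in range(nsteps):
                Ap = fvp(p)
                alpha = rdotr / torch.dot(p, Ap)
                x += alpha * p
                r -= alpha * Ap
                new_rdotr = torch.dot(r, r)
                if new_rdotr < residual_tol:
                    break
                p = r + (new_rdotr / rdotr) * p
                rdotr = new_rdotr
            return x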
        print('PolicyNet optimization: old loss = {}, new loss = {}'.format(old_loss, new_loss))

        if i_episode % args.log_num_episode == 0:
-            print('Finished episode: {} Mean Reward: {}'.format(i_episode, np.mean(reward_record)))
+            print('Finished episode: {} Mean Reward: {}'.format(i_episode, reward_record[-1]))
            print('-----------------')


    return reward_record


if __name__ == '__main__':
    datestr = datetime.datetime.now().strftime('%Y-%m-%d')
    args = add_arguments()

    record_dfs = pd.DataFrame(columns=['steps', 'reward'])
    reward_cols = []
    for i in range(args.num_parallel_run):
        args.seed += 1
        reward_record = pd.DataFrame(trpo(args))
        record_dfs = record_dfs.merge(reward_record, how='outer', on='steps', suffixes=('', '_{}'.format(i)))
        reward_cols.append('reward_{}'.format(i))

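    # Note (added): the outer merge on 'steps' leaves NaNs where parallel runs
    # logged at different step counts; forward/backward fill below aligns the
    # reward columns before averaging across runs.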
    record_dfs = record_dfs.drop(columns='reward').sort_values(by='steps', ascending=True).ffill().bfill()
    record_dfs['reward_mean'] = record_dfs[reward_cols].mean(axis=1)
    record_dfs['reward_std'] = record_dfs[reward_cols].std(axis=1)
    record_dfs['reward_smooth'] = record_dfs['reward_mean'].ewm(span=20).mean()
    record_dfs.to_csv(joindir(RESULT_DIR, 'trpo-record-{}-{}.csv'.format(args.env_name, datestr)))

    # Plot
    plt.figure(figsize=(12, 6))
    plt.plot(record_dfs['steps'], record_dfs['reward_mean'], label='trajectory reward')
    plt.plot(record_dfs['steps'], record_dfs['reward_smooth'], label='smoothed reward')
    plt.fill_between(record_dfs['steps'], record_dfs['reward_mean'] - record_dfs['reward_std'],
        record_dfs['reward_mean'] + record_dfs['reward_std'], color='b', alpha=0.2)
    plt.legend()
    plt.xlabel('steps of env interaction (sample complexity)')
    plt.ylabel('average reward')
    plt.title('TRPO on {}'.format(args.env_name))
    plt.savefig(joindir(RESULT_DIR, 'trpo-{}-{}.pdf'.format(args.env_name, datestr)))