@@ -0,0 +1,85 @@
absl-py==0.3.0
astor==0.7.1
autopep8==1.3.5
backcall==0.1.0
bleach==2.1.4
certifi==2018.8.24
chardet==3.0.4
colorama==0.3.9
cycler==0.10.0
decorator==4.3.0
defusedxml==0.5.0
entrypoints==0.2.3
gast==0.2.0
grpcio==1.14.1
html5lib==1.0.1
idna==2.7
ipykernel==5.0.0
ipython==7.0.1
ipython-genutils==0.2.0
ipywidgets==7.4.2
isort==4.3.4
jedi==0.12.1
Jinja2==2.10
jsonschema==2.6.0
jupyter==1.0.0
jupyter-client==5.2.3
jupyter-console==5.2.0
jupyter-core==4.4.0
kiwisolver==1.0.1
lxml==4.2.5
Markdown==2.6.11
MarkupSafe==1.0
matplotlib==2.2.2
mccabe==0.6.1
mistune==0.8.3
nbconvert==5.4.0
nbformat==4.4.0
nltk==3.3
notebook==5.7.0
numpy==1.14.5
opencv-python==3.4.2.17
pandas==0.23.4
pandas-datareader==0.7.0
pandocfilters==1.4.2
parso==0.3.1
pickleshare==0.7.5
Pillow==5.2.0
prometheus-client==0.3.1
prompt-toolkit==1.0.15
protobuf==3.6.0
pycodestyle==2.4.0
Pygments==2.2.0
pyparsing==2.2.0
python-dateutil==2.7.3
pytz==2018.5
pywinpty==0.5.4
pyzmq==17.1.2
qtconsole==4.4.1
requests==2.19.1
scikit-learn==0.19.2
scipy==1.1.0
Send2Trash==1.5.0
simplegeneric==0.8.1
six==1.11.0
tensorboard==1.10.0
tensorboardX==1.4
tensorflow==1.10.0
termcolor==1.1.0
terminado==0.8.1
testpath==0.4.1
torch==0.4.1
torchfile==0.1.0
torchnet==0.0.4
torchvision==0.2.1
tornado==5.1.1
traitlets==4.3.2
urllib3==1.23
visdom==0.1.8.5
wcwidth==0.1.7
webencodings==0.5.1
websocket-client==0.53.0
Werkzeug==0.14.1
widgetsnbextension==3.4.2
wrapt==1.10.11
xgboost==0.80
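These pins date from 2018 and assume a correspondingly old interpreter (roughly Python 3.6). Note that gym, which the training script below imports, is not pinned in this list. A minimal install sketch, assuming the list is saved as requirements.txt (the filename is not shown in this diff):

    pip install -r requirements.txt
    pip install gym   # imported by the script but absent from the pinned list; version left unspecified here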
@@ -0,0 +1,185 @@
import gym
import numpy as np
from itertools import chain

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

# Unwrap to drop the TimeLimit wrapper; CartPole-v1 otherwise truncates episodes at 500 steps.
env = gym.make('CartPole-v1')
env = env.unwrapped
state_number = env.observation_space.shape[0]
action_number = env.action_space.n
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class Actor(nn.Module):
    """Policy network: maps a state to a probability distribution over actions."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(state_number, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, action_number),
            nn.Softmax(dim=-1),
        )

    def forward(self, state):
        pi = self.layers(state)  # (batch_size, action_number)
        return pi


class Critic(nn.Module):
    """Value network: estimates the state value V(s)."""

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(state_number, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 1),
        )

    def forward(self, state):
        value = self.layers(state).squeeze(-1)  # (batch_size,)
        return value

class ActorCritic():
    """Advantage actor-critic agent; both networks share a single Adam optimizer."""

    def __init__(
        self,
        gamma=0.99,
        update_steps=1,
        lr=5e-4,
        weight_decay=0.0,
    ):
        self.gamma = gamma
        self.update_steps = update_steps

        self.buffer = []
        self.actor = Actor().to(device)
        self.critic = Critic().to(device)
        self.optimizer = torch.optim.Adam(
            chain(self.actor.parameters(), self.critic.parameters()),
            lr=lr, weight_decay=weight_decay
        )
        self.loss_fct = nn.SmoothL1Loss()

    @torch.no_grad()
    def choose_action(self, state):
        # Sample an action from the current policy distribution.
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        pi = self.actor(state)
        dist = Categorical(pi)
        action = dist.sample().item()
        return action

    @torch.no_grad()
    def get_value(self, state):
        # Estimate V(s) for a single state (not used by the training loop below).
        state = torch.from_numpy(state).float().unsqueeze(0).to(device)
        value = self.critic(state)
        return value

    def store_experience(self, experience):
        # experience: (state, action, reward, next_state, done)
        self.buffer.append(experience)

    def update(self):
        # Unpack the episode buffer into batched tensors.
        get_tensor = lambda x: torch.as_tensor(
            np.array([b[x] for b in self.buffer])
        ).to(device)
        states = get_tensor(0).float()
        actions = get_tensor(1).long()
        rewards = get_tensor(2).float()
        next_states = get_tensor(3).float()
        done = get_tensor(4).float()

        # # Trick 2: give each step t a different weight (discounted reward-to-go).
        # for t in reversed(range(0, rewards.size(0) - 1)):
        #     rewards[t] = rewards[t] + self.gamma * rewards[t + 1]
        # Trick 1: subtract a reward baseline b (here the mean) and normalize; this helps convergence.
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

        # Compute the targets.
        with torch.no_grad():
            # Action-value function Q^{\pi}(s, a) = r(s, a) + \gamma \sum_{s' \in S} P(s'|s, a) V^{\pi}(s')
            # (1 - done) masks out the bootstrap term at terminal states.
            target_v = rewards + self.gamma * self.critic(next_states) * (1 - done)
            # Advantage function A^{\pi}(s, a) = Q^{\pi}(s, a) - V^{\pi}(s)
            advantage = target_v - self.critic(states)

        for i in range(self.update_steps):
            # Compute the losses.
            pi = self.actor(states)
            action_log_probs = torch.sum(
                pi.log() * F.one_hot(actions, num_classes=action_number), dim=1
            )

            loss_actor = -(action_log_probs * advantage).mean()  # policy gradient weighted by the TD error

            value = self.critic(states)
            loss_critic = self.loss_fct(value, target_v)

            loss = loss_actor + loss_critic
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        # Clear the buffer.
        del self.buffer[:]

        return loss.item()

def train(agent, num_episodes=5000, render=False):
    step = 0
    for i in range(num_episodes):
        total_rewards = 0
        done = False
        state, _ = env.reset()
        while not done:
            step += 1
            if render: env.render()
            # Choose an action.
            action = agent.choose_action(state)
            # Interact with the environment.
            next_state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated  # the unwrapped env has no time limit, so truncated stays False
            # Reward shaping; you can also skip this and use the raw reward directly, both converge.
            x, x_dot, theta, theta_dot = next_state
            r1 = (env.x_threshold - abs(x)) / env.x_threshold - 0.8
            r2 = (env.theta_threshold_radians - abs(theta)) / env.theta_threshold_radians - 0.5
            r3 = 3 * r1 + r2
            # Store the experience.
            agent.store_experience((state, action, r3, next_state, done))
            # Update the state.
            state = next_state
            total_rewards += reward

        # Episode over: update the parameters.
        loss = agent.update()
        if i % 50 == 0:
            print('episode:{} reward:{} loss:{:.4f}'.format(i, total_rewards, loss))

def test(agent, num_episodes=10, render=False):
    env = gym.make('CartPole-v1', render_mode="human" if render else None)
    step = 0
    eval_rewards = []
    for i in range(num_episodes):
        total_rewards = 0
        done = False
        state, _ = env.reset()
        while not done:
            step += 1
            if render: env.render()
            # Choose an action.
            action = agent.choose_action(state)
            # Interact with the environment.
            next_state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated  # also stop at the 500-step time limit
            # Update the state.
            state = next_state
            total_rewards += reward
        eval_rewards.append(total_rewards)
    return sum(eval_rewards) / len(eval_rewards)

if __name__ == "__main__":
    agent = ActorCritic()
    train(agent, render=False)
    test(agent, render=True)
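The script trains and then immediately evaluates in the same process. As a small usage sketch that is not part of the commit, the learned policy could also be checkpointed and restored with the standard torch.save / load_state_dict calls; the path actor.pt below is a hypothetical name:

    # Hypothetical checkpointing, e.g. appended to the __main__ block after train().
    torch.save(agent.actor.state_dict(), "actor.pt")

    # Later: rebuild an agent, restore the policy, then evaluate with test().
    new_agent = ActorCritic()
    new_agent.actor.load_state_dict(torch.load("actor.pt", map_location=device))
    print("mean eval reward:", test(new_agent, render=False))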