Skip to content

機械学習

seigot edited this page Jan 30, 2022 · 48 revisions

機械学習(TGS:Torii Gakushu System)に関する情報共有ページ

準備

# 追記:https://github.com/seigot/burger_war_dev/wikiのstep2.のpackage追加の部分 〜 準備を実行するスクリプト
# これが実行できたら実行モードに進める
mkdir -p ~/tmp; cd ~/tmp;
sudo curl -LJO https://raw.githubusercontent.com/seigot/tools/master/robot/install-kenjirotorii.sh
bash install-kenjirotorii.sh

https://github.com/seigot/burger_war_dev/wikiのstep2.のpackage追加の部分を全てインストールしておく。
その後、元リポジトリ or fork したリポジトリを取得

cd $HOME/catkin_ws/src
git clone https://github.com/kenjirotorii/burger_war_kit
git clone https://github.com/kenjirotorii/burger_war_dev

# obstacle detectorの取得
cd ~/catkin_ws/src
git clone https://github.com/tysik/obstacle_detector.git

# catkin build
cd $HOME/catkin_ws
catkin build
source ~/catkin_ws/devel/setup.bash

# library install
cd $HOME/catkin_ws/src/burger_war_dev
pip install -r requirements.txt

実行モード

普通に実行する

## ターミナルから実行
cd ~/catkin_ws/src/burger_war_kit
bash scripts/sim_with_judge.sh

## 別のターミナルから実行
cd ~/catkin_ws/src/burger_war_kit
bash scripts/start.sh
bash scripts/start.sh -l2   # level2
bash scripts/start.sh -l3   # level3

※訓練モード実施後の場合は、訓練モードのパッチを外してから実行する

cd ~/catkin_ws/src/burger_war_dev
patch -p1 -R < 20211114_burger_war_dev.patch # 訓練モードのパッチを外してから実行する

訓練モード

学習のためには以下のstepを踏む必要がある。

https://github.com/kenjirotorii/burger_war_dev/blob/main/burger_war_dev/scripts/dqn_learning.py
step1.serverを立ち上げ
step2.launchを訓練モードにする(以下のpatchを当てる)
step3.普通に実行する(sim_with_judge.sh/start.sh)を実行

step1.serverを立ち上げ

cd ~/catkin_ws/src/burger_war_dev/scripts
python dqn_learning.py

学習モデルを保存するパスが、実行時のパスに依存しているため以下を実行する

cd ~/catkin_ws
python src/burger_war_dev/burger_war_dev/scripts/dqn_learning.py
 ---> waiting agent client: port 5010... と表示される
# たまに、不意にserverがdisconnect状態に陥る(socket.error等による)
# この場合は強制終了→再起動するとよい、はず。

# 強制終了
# ps -aux | grep dqn_learning # ProcessIDを検索
ProcessID=`ps -aux | grep dqn_learning | sed -e 's/  */ /g' | cut -d' ' -f2`
echo ${ProcessID}
kill -s KILL ${ProcessID}

# 再起動
python src/burger_war_dev/burger_war_dev/scripts/dqn_learning.py

step2.launchを訓練モードにする(以下のpatchを当てる)

start.sh実行時に以下を切り替えるpatch dqn_operation.py:訓練モード用のファイル ★これに切り替える
dqn_self_play.py:実行モード用のファイル ★これはデフォルト

cd ~/catkin_ws/src/burger_war_dev
wget https://raw.githubusercontent.com/seigot/tools/master/robot/20211114_burger_war_dev.patch
patch -p1 < 20211114_burger_war_dev.patch  # test.patchは以下をコピーして作成

必要に応じて以下をコメントアウトする(GPU RTX 3060などは非対応のため明示的にcpuを使う必要があるらしい)

emacs ~/catkin_ws/src/burger_war_dev/burger_war_dev/scripts/agents/brain.py
self.device = torch.device('cpu') #temporary setting

20211114_burger_war_dev.patchは後述

step3.実行する(sim_with_judge.sh/start.shを実行する)

# step1.の手順でAgentServerを起動している前提
# 別ターミナルから、autotestを実行(sim_with_judge.sh, start.sh を繰り返し起動)

# まずautotest.shにパッチを当てる(訓練相手を決定する)
cd ~/catkin_ws/src/burger_war_kit
wget https://raw.githubusercontent.com/seigot/tools/master/robot/20211114_tgs_autotest.patch
patch -p1 < 20211114_tgs_autotest.patch

# 実行する
cd burger_war/launch
bash ../../autotest/autotest.sh

docker利用版
(こちらは任意。ROS melodic環境がdockerに含まれているdockerを利用する版。結果は上記と同じになるはず。
 ただ環境によってはエラーになる可能性を感じる。dockerに関するエラー解決する自信がなければやめておいた方が時間ロスしなくていいかもしれない)

## 参考URL(https://github.com/seigot/burger_war_dev/blob/main/STARTUP_GUIDE.md)

cd ~/catkin_ws/src/burger_war_dev

## インストール & build
bash commands/docker-install.sh amd64
bash commands/docker-build.sh

## docker起動
bash commands/docker-launch.sh

## gazeboが起動するか確認
bash commands/kit.sh -c gazebo
  --> ここでgazeboが起動しなければこちら記載のエラー解決が必要、自信がなければやめておいた方が時間ロスしなくていいかもしれない
      https://github.com/seigot/burger_war_dev/blob/main/STARTUP_GUIDE.md

## dockerの中に必要なライブラリを全ていれる
docker exec -it burger-war-dev bash
(必要なライブラリを全ていれる。本来はDockerfileからインストールしたいのでここの手順は暫定)
exit

## AgentServerを起動
bash commands/kit.sh -c "pip install -r src/burger_war_dev/requirements.txt"
bash commands/kit.sh -c "python src/burger_war_dev/burger_war_dev/scripts/dqn_learning.py"

## autotestを実行(sim_with_judge.sh, start.sh を繰り返し起動)
bash commands/kit.sh -s ../autotest/autotest.sh

処理の内容を補足する

構成


20211114_burger_war_dev.patchは後述

diff --git a/burger_war_dev/launch/your_burger.launch b/burger_war_dev/launch/your_burger.launch
index 362d4e9..341694e 100644
--- a/burger_war_dev/launch/your_burger.launch
+++ b/burger_war_dev/launch/your_burger.launch
@@ -7,17 +7,17 @@
 
     <include file="$(find burger_navigation)/launch/burger_navigation.launch" />
 
-    <group if="$(eval self_play==0)"> 
+    <group if="$(eval self_play==1)"> 
       <node pkg="burger_war_dev" type="dqn_operation.py" name="DQNRun" output="screen">
         <param name="side" value="$(arg side)"/>
       </node>
     </group>
 
-    <group if="$(eval self_play==1)"> 
+    <group if="$(eval self_play==0)"> 
       <node pkg="burger_war_dev" type="dqn_self_play.py" name="DQNRun" output="screen">
         <param name="side" value="$(arg side)"/>
       </node>
     </group>
 
     <!-- End of your space  -->
-</launch>
\ No newline at end of file
+</launch>
diff --git a/burger_war_dev/scripts/agents/agent.py b/burger_war_dev/scripts/agents/agent.py
index 5801616..842a703 100644
--- a/burger_war_dev/scripts/agents/agent.py
+++ b/burger_war_dev/scripts/agents/agent.py
@@ -3,6 +3,7 @@
 
 from brain import Brain
 
+print("agent.pyを実行する")
 
 class Agent:
     """
@@ -16,6 +17,7 @@ class Agent:
             capacity (int): capacity of memory
             gamma (int): discount rate
         """
+	print("Agentの初期化を行う")
         self.brain = Brain(num_actions, batch_size, capacity, gamma, prioritized, lr)  # エージェントが行動を決定するための頭脳を生成
 
     def update_policy_network(self):
@@ -24,6 +26,7 @@ class Agent:
         Args:
             
         """
+	print("Policy network modelの更新")
         self.brain.replay()
 
     def get_action(self, state, episode, policy_mode, debug):
@@ -35,9 +38,10 @@ class Agent:
         Return:
             action (Tensor): action (number)
         """
+	print("Actionを取得する")
         action = self.brain.decide_action(state, episode, policy_mode, debug)
         return action
-
+	
     def memorize(self, state, action, state_next, reward):
         """
         memorize current state, action, next state and reward
@@ -47,6 +51,7 @@ class Agent:
             state_next (dict): next state
             reward (int): reward
         """
+	print("現在の状態、行動、次の状態、報酬を記憶する")
         self.brain.memory.push(state, action, state_next, reward)
 
     def save_model(self, path):
@@ -55,6 +60,7 @@ class Agent:
         Args:
             path (str): path to save
         """
+	print("Modelを保存する")
         self.brain.save_model(path)
 
     def load_model(self, path):
@@ -63,6 +69,7 @@ class Agent:
         Args:
             path (str): path to load
         """
+	print("Modelをloadする")
         self.brain.load_model(path)
 
     def save_memory(self, path):
@@ -71,6 +78,7 @@ class Agent:
         Args:
             path (str): path to save
         """
+	print("Save Memory")
         self.brain.save_memory(path)
 
     def load_memory(self, path):
@@ -79,16 +87,19 @@ class Agent:
         Args:
             path (str): path to load
         """
+	print("Load Memory")
         self.brain.load_memory(path)
 
     def update_target_network(self):
         """
         update target network model
         """
+	print("Target neweork modelを更新する")
         self.brain.update_target_network()
     
     def detach(self):
         """
         detach agent (for server-client implementation)
         """
+	print("エージェントのデタッチ(サーバークライアント実装用)")
         pass
diff --git a/burger_war_dev/scripts/agents/brain.py b/burger_war_dev/scripts/agents/brain.py
index e3d3892..97742b5 100644
--- a/burger_war_dev/scripts/agents/brain.py
+++ b/burger_war_dev/scripts/agents/brain.py
@@ -21,16 +21,19 @@ from utils.permemory import PERMemory
 from networks.maskNet import MaskNet
 import pickle
 
+print("brain.pyを実行する")
 #------------------------------------------------
 
 class Brain:
     TARGET_UPDATE = 10
     def __init__(self, num_actions, batch_size=32, capacity=10000, gamma=0.99, prioritized=True, lr=0.0005):
+	print("Brainの初期化")
         self.batch_size = batch_size
         self.gamma = gamma
         self.num_actions = num_actions
         self.prioritized = prioritized
 
+
         # Instantiate memory object
         if self.prioritized:
             print('* Prioritized Experience Replay Mode')
@@ -46,8 +49,9 @@ class Brain:
 
         # Set device type; GPU or CPU (Use GPU if available)
         self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
-        #self.device = torch.device('cpu')
+        #self.device = torch.device('cpu') #temporary setting
         self.policy_net = self.policy_net.to(self.device)
+
         self.target_net = self.target_net.to(self.device)
 
         print('using device:', self.device)
@@ -58,7 +62,7 @@ class Brain:
 
     def replay(self):
         """Experience Replayでネットワークの重みを学習 """
-
+	print("Experience Replayでネットワークの重みを学習 ")
         # Do nothing while size of memory is lower than batch size
         if len(self.memory) < self.batch_size:
             return
@@ -157,7 +161,6 @@ class Brain:
     def decide_action(self, state, episode, policy_mode="epsilon", debug=True):
         """
         policy
-
         Args:
             state (State): state object
             episode (int): current episode
@@ -169,7 +172,7 @@ class Brain:
 
         if not debug:
             self.policy_net.eval()  # ネットワークを推論モードに切り替える
-
+            print("推論を行う")
             # Set device type; GPU or CPU
             input_pose = Variable(state.pose).to(self.device)
             input_lidar = Variable(state.lidar).to(self.device)
@@ -185,6 +188,7 @@ class Brain:
         if policy_mode == "epsilon":
             # ε-greedy法で徐々に最適行動のみを採用する
             # epsilon = 0.5 * (1 / (episode + 1))
+            print("ε-greedy法を採用する")
             if episode < 50:
                 epsilon = 0.25
             elif episode < 100:
@@ -194,7 +198,6 @@ class Brain:
 
             if epsilon <= np.random.uniform(0, 1):
                 self.policy_net.eval()  # ネットワークを推論モードに切り替える
-
                 # Set device type; GPU or CPU
                 input_pose = Variable(state.pose).to(self.device)
                 input_lidar = Variable(state.lidar).to(self.device)
@@ -213,6 +216,7 @@ class Brain:
                 print("Random action: {}".format(action.item()))
 
         elif policy_mode == "boltzmann":
+            print("boltzmann法を採用する")
             self.policy_net.eval()  # ネットワークを推論モードに切り替える
 
             # Set device type; GPU or CPU
diff --git a/burger_war_dev/scripts/agents/connection.py b/burger_war_dev/scripts/agents/connection.py
index 7c26738..31a7fd1 100644
--- a/burger_war_dev/scripts/agents/connection.py
+++ b/burger_war_dev/scripts/agents/connection.py
@@ -31,7 +31,6 @@ import queue
 import select
 import multiprocessing as mp
 
-
 def send_recv(conn, sdata):
     conn.send(sdata)
     rdata = conn.recv()
diff --git a/burger_war_dev/scripts/dqn_operation.py b/burger_war_dev/scripts/dqn_operation.py
index 179ed62..f370387 100755
--- a/burger_war_dev/scripts/dqn_operation.py
+++ b/burger_war_dev/scripts/dqn_operation.py
@@ -32,6 +32,8 @@ from utils.lidar_transform import lidar_transform
 
 import cv2
 
+print("dqn_operation.pyを実行")
+
 # config
 FIELD_SCALE = 2.4
 FIELD_MARKERS = [
@@ -45,7 +47,6 @@ ROBOT_MARKERS = {
 
 JUDGE_URL = ""
 
-
 # functions
 def send_to_judge(url, data):
     res = requests.post(url,
@@ -465,7 +466,7 @@ class DQNBot:
 
     
 if __name__ == "__main__":
-
+    print("dqn_operation.py_mainを開始")
     rospy.init_node('dqn_run')
     JUDGE_URL = rospy.get_param('/send_id_to_judge/judge_url')
 
diff --git a/burger_war_dev/scripts/dqn_self_play.py b/burger_war_dev/scripts/dqn_self_play.py
index 5427ffb..cef50ee 100755
--- a/burger_war_dev/scripts/dqn_self_play.py
+++ b/burger_war_dev/scripts/dqn_self_play.py
@@ -31,6 +31,7 @@ from utils.wallAvoid import punish_by_count, punish_by_min_dist, manual_avoid_wa
 from utils.lidar_transform import lidar_transform
 from agents.agent import Agent
 
+print("dqn_self_play.pyを実行")
 
 # config
 FIELD_SCALE = 2.4
@@ -472,7 +473,7 @@ class DQNBot:
 
     
 if __name__ == "__main__":
-
+    print("dqn_self_play_main_.pyを実行")
     rospy.init_node('dqn_run')
     JUDGE_URL = rospy.get_param('/send_id_to_judge/judge_url')
 
@@ -487,7 +488,7 @@ if __name__ == "__main__":
     print("name: {}, server: {}".format(ROBOT_NAME, JUDGE_URL))
 
     # parameters
-
+    print("parametersの設定")
     ONLINE = True
     POLICY = "epsilon"
     DEBUG = False
@@ -496,10 +497,12 @@ if __name__ == "__main__":
     MANUAL_AVOID = False
 
     # wall avoidance
+    print("wall avoidanceの設定")
     DIST_TO_WALL_TH = 0.18
     NUM_LASER_CLOSE_TO_WALL_TH = 30
 
     # action lists
+    print("action lists設定")
     VEL = 0.2
     OMEGA = 30 * 3.14/180
     ACTION_LIST = [
@@ -511,6 +514,7 @@ if __name__ == "__main__":
     ]
 
     # agent config
+    print("agent config設定")
     UPDATE_Q_FREQ = 10
     BATCH_SIZE = 16
     MEM_CAPACITY = 2000
diff --git a/burger_war_dev/scripts/networks/net.py b/burger_war_dev/scripts/networks/net.py
index 85e6397..6eb310a 100755
--- a/burger_war_dev/scripts/networks/net.py
+++ b/burger_war_dev/scripts/networks/net.py
@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+print("net.pyを実行する")
 class Net(nn.Module):
     def __init__(self, output_size):
         """
diff --git a/burger_war_dev/scripts/networks/net2.py b/burger_war_dev/scripts/networks/net2.py
index ddb8e1c..77ccb1f 100755
--- a/burger_war_dev/scripts/networks/net2.py
+++ b/burger_war_dev/scripts/networks/net2.py
@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+print("net2.pyを実行する")
 class Net(nn.Module):
     def __init__(self, output_size):
         """
diff --git a/burger_war_dev/scripts/utils/lidar_transform.py b/burger_war_dev/scripts/utils/lidar_transform.py
index b3cc3b4..4764312 100644
--- a/burger_war_dev/scripts/utils/lidar_transform.py
+++ b/burger_war_dev/scripts/utils/lidar_transform.py
@@ -1,6 +1,5 @@
 import numpy as np
 
-
 def lidar_transform(lidar, debug=True):
 
     if debug:
@@ -30,4 +29,4 @@ def lidar_transform(lidar, debug=True):
 
                     zero_idx = []
 
-    return lidar
\ No newline at end of file
+    return lidar
diff --git a/burger_war_dev/scripts/utils/permemory.py b/burger_war_dev/scripts/utils/permemory.py
index 86e17d3..53f5768 100644
--- a/burger_war_dev/scripts/utils/permemory.py
+++ b/burger_war_dev/scripts/utils/permemory.py
@@ -8,6 +8,8 @@ from transition import Transition
 import numpy as np
 from replaymemory import ReplayMemory
 
+print("permemory.pyを実行する")
+
 class PERMemory(ReplayMemory):
     epsilon = 0.0001
     alpha = 0.6
diff --git a/burger_war_dev/scripts/utils/random_replaymemory.py b/burger_war_dev/scripts/utils/random_replaymemory.py
index 77f12ed..46a05c4 100644
--- a/burger_war_dev/scripts/utils/random_replaymemory.py
+++ b/burger_war_dev/scripts/utils/random_replaymemory.py
@@ -6,6 +6,8 @@ from state import State
 from transition import Transition
 from replaymemory import ReplayMemory
 
+print("random_replaymemory.pyを実行する")
+
 class RandomReplayMemory(ReplayMemory):
 
     def push(self, state, action, state_next, reward):
diff --git a/burger_war_dev/scripts/utils/replaymemory.py b/burger_war_dev/scripts/utils/replaymemory.py
index db3c89c..b791c6c 100644
--- a/burger_war_dev/scripts/utils/replaymemory.py
+++ b/burger_war_dev/scripts/utils/replaymemory.py
@@ -4,6 +4,7 @@
 from state import State
 from transition import Transition
 
+print("replaymemory.pyを実行する")
 class ReplayMemory(object):
 
     def __init__(self, CAPACITY):
diff --git a/burger_war_dev/scripts/utils/state.py b/burger_war_dev/scripts/utils/state.py
index bc0c5bb..fd13776 100755
--- a/burger_war_dev/scripts/utils/state.py
+++ b/burger_war_dev/scripts/utils/state.py
@@ -3,6 +3,8 @@
 
 from collections import namedtuple
 
+print("state.pyを実行する")
+
 State = namedtuple (
     'State', ('pose', 'lidar', 'image', 'mask')
-)
\ No newline at end of file
+)
diff --git a/burger_war_dev/scripts/utils/sumtree.py b/burger_war_dev/scripts/utils/sumtree.py
index 221a03d..f6ee926 100755
--- a/burger_war_dev/scripts/utils/sumtree.py
+++ b/burger_war_dev/scripts/utils/sumtree.py
@@ -7,6 +7,8 @@
 # and add some functions
 import numpy
 
+print("sumtree.pyを実行する")
+
 class SumTree:
     write = 0
 
@@ -62,4 +64,4 @@ class SumTree:
         idx = self._retrieve(0, s)
         dataIdx = idx - self.capacity + 1
 
-        return (idx, self.tree[idx], self.data[dataIdx])
\ No newline at end of file
+        return (idx, self.tree[idx], self.data[dataIdx])
diff --git a/burger_war_dev/scripts/utils/transition.py b/burger_war_dev/scripts/utils/transition.py
index 5ad4d2d..e3b2595 100755
--- a/burger_war_dev/scripts/utils/transition.py
+++ b/burger_war_dev/scripts/utils/transition.py
@@ -3,6 +3,8 @@
 
 from collections import namedtuple
 
+print("transition.pyを実行する")
+
 Transition = namedtuple(
     'Transition', ('state', 'action', 'next_state', 'reward')
-)
\ No newline at end of file
+)
diff --git a/burger_war_dev/scripts/utils/wallAvoid.py b/burger_war_dev/scripts/utils/wallAvoid.py
index 8513bba..af75118 100644
--- a/burger_war_dev/scripts/utils/wallAvoid.py
+++ b/burger_war_dev/scripts/utils/wallAvoid.py
@@ -3,6 +3,7 @@
 
 import torch
 
+print("wallAvoid.pyを実行する")
 
 def punish_by_count(lidar, dist_th=0.2, count_th=90):
     # Check LiDAR data to punish for AMCL failure
@@ -18,7 +19,7 @@ def punish_by_count(lidar, dist_th=0.2, count_th=90):
 
         # Punish if too many lasers close to obstacle
         if count_too_close > count_th:
-            print("### Too close to the wall, get penalty ###")
+            print("### Too close to the wall, get penalty1 ###")
             punish = -0.5
 
     return punish
@@ -30,7 +31,7 @@ def punish_by_min_dist(lidar, dist_th=0.15):
     if lidar is not None:
         lidar_1d = lidar.squeeze()
         if lidar_1d.min() < dist_th:
-            print("### Too close to the wall, get penalty ###")
+            print("### Too close to the wall, get penalty2 ###")
             punish = -0.5
 
     return punish
@@ -58,6 +59,7 @@ def manual_avoid_wall(lidar, dist_th=0.2, count_th=90):
         ], key=lambda e: e[0])
             
     else:
+	print("safety_for_wall")
         avoid = False
         linear_x = None
         angular_z = None

dqn_operation.py

def __init__(self, robot="r", online=False, policy_mode="epsilon", debug=True,
   各種処理や変数を初期化
  publisher/subscriberを登録

def callback_lidar(self, data)
  lidarの生データを取得
  lidarのrangeを取得

def callback_image(self, data)
  画像データを取得
  微分画像を学習用に保存している?

def callback_odom(self, data)
  odometoryを取得
  自己位置として内部変数に記憶?

def callback_amcl(self, data)
  amclを取得
  自己位置として内部変数に記憶?

def callback_warstate(self, event)
  warstateを取得
  学習用にスコア情報を記憶?

def get_reward(self, past, current)
  rewardを計算する
  自分の獲得スコア、敵の獲得スコア、位置(bad_position)の合計値をrewardの値にする

def strategy(self)
  戦略(動き方)を決定する
  epsilon/boltzmann/(avoid)のどのモードで動くかをきめる
  self.my_pose/self.lidar_ranges/self.image/self.maskから、次のtwist_x,zを求めてpublishする

def move_robot(self, model_name, position=None, orientation=None, linear=None, angular=None)
  未使用

def init_amcl_pose(self)
  amcl_poseを初期化、reset時に実施する用

def stop(self)
  ゲームを停止する
  gazeboを止める

def restart(self)
  ゲームを再開する
  stop/resetの後に呼び出したい様子

def reset(self)
  ゲームをresetする

def train(self, n_epochs=20)
  update_policy_networkする

def run(self, rospy_rate=1)
  メインループがある関数


dqn_self_play.py

  • dqn_operation.pyとほぼ同じ。学習モデルをloadして実行する。
def send_to_judge(url, data)
def __init__(self, robot="r", online=False, policy_mode="epsilon", debug=True, save_path=None, load_path=None, manual_avoid=False)
def callback_lidar(self, data)
def callback_image(self, data)
def min_max(x, axis=None)
def callback_odom(self, data)
def callback_amcl(self, data)
def callback_warstate(self, event)
def get_reward(self, past, current)
def strategy(self)
def move_robot(self, model_name, position=None, orientation=None, linear=None, angular=None)
def init_amcl_pose(self)
def stop(self)
def restart(self)
def reset(self)
def train(self, n_epochs=20)
def run(self, rospy_rate=1)

dqn_learning.py

scripts/agents/agent_conn.py にサーバ側処理の実体がある。

class AgentServer
def __init__(self, port=5010)
  AgentServerを初期化する

def _wait(self)
  AgentServerをClientからのrequest wait状態にする
  runから呼ばれる

def run(self)
  AgentServerのメインループ

class AgentClient
def __init__(self, server_address, port, num_actions, batch_size=32, capacity=10000, gamma=0.99, prioritized=True, lr=0.0005)
  AgentClientを初期化する

def update_policy_network(self)
  update policy network model

def get_action(self, state, episode, policy_mode, debug)
   サーバからactionを取得する

def memorize(self, state, action, state_next, reward)
  memorize current state, action, next state and reward
  サーバに記憶用のデータを送信する

def save_model(self, path)
  save model
  サーバに学習モデルを保存する

def load_model(self, path)
  load model
  サーバから学習モデルをロードする

def save_memory(self, path)
  save memory
  サーバへデータ送信する

def load_memory(self, path)
  load memory
  サーバからデータ取得する

def update_target_network(self)
  update target network model

def detach(self)
  agentをdetachする

Clone this wiki locally