-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtestairenv.cpp
132 lines (106 loc) · 3.61 KB
/
testairenv.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
/*
* testairenv.cpp
*
* Created on: Apr 5, 2021
* Author: zf
*/
#include <gymtest/env/airenv.h>
#include <log4cxx/logger.h>
#include <log4cxx/basicconfigurator.h>
namespace
{
// Logger shared by every test routine in this translation unit; all output
// is emitted under the "testgym" logger name.
log4cxx::LoggerPtr logger = log4cxx::Logger::getLogger("testgym");
}
namespace {
//BeamRider: 9
void testGetInfo(std::string serverAddr) {
const int clientNum = 2;
// std::string serverAddr = "tcp://127.0.0.1:10201";
LOG4CXX_INFO(logger, "To connect to " << serverAddr);
// AirEnv env(serverAddr, "SpaceInvaders-v0", clientNum);
// AirEnv env(serverAddr, "Pong-v0", clientNum);
//Qbert = 6
//Pacman = 9
//Alien = 18, 3 lives
//Assault = 7, 4 lives
AirEnv env(serverAddr, "AssaultNoFrameskip-v4", clientNum);
auto info = env.init();
auto actionSpace = std::get<1>(info);
auto obSpace = std::get<0>(info);
LOG4CXX_INFO(logger, "Action space: " << actionSpace.type << ", " << actionSpace.shape);
LOG4CXX_INFO(logger, "Observation space:" << obSpace.type << "-" << obSpace.shape);
auto rc = env.reset();
LOG4CXX_INFO(logger, "next state: " << rc.size());
auto actions = std::vector<long>(clientNum, 2);
auto stepResult = env.step(1);
auto obsvVec = std::get<0>(stepResult);
auto rewardVec = std::get<1>(stepResult);
auto doneVec = std::get<2>(stepResult);
LOG4CXX_INFO(logger, "obsvVec: " << obsvVec.size());
LOG4CXX_INFO(logger, "reward: " << rewardVec);
LOG4CXX_INFO(logger, "done: " << doneVec);
}
void testEpisode() {
const int clientNum = 2;
std::string serverAddr = "tcp://127.0.0.1:10201";
LOG4CXX_INFO(logger, "To connect to " << serverAddr);
AirEnv env(serverAddr, "PongNoFrameskip-v4", clientNum);
auto info = env.init();
auto actionSpace = std::get<1>(info);
auto obSpace = std::get<0>(info);
LOG4CXX_INFO(logger, "Action space: " << actionSpace.type << ", " << actionSpace.shape);
LOG4CXX_INFO(logger, "Observation space:" << obSpace.type << "-" << obSpace.shape);
bool isDone = false;
auto obsv = env.reset();
while (!isDone) {
auto actions = std::vector<long>(clientNum, 3);
auto stepResult = env.step(actions, true);
obsv = std::get<0>(stepResult);
auto rewardVec = std::get<1>(stepResult);
auto doneVec = std::get<2>(stepResult);
LOG4CXX_INFO(logger, "reward: " << rewardVec);
isDone = false;
for (const auto &done: doneVec) {
if (done) {
isDone = true;
break;
}
}
}
}
void testReset() {
const int clientNum = 2;
std::string serverAddr = "tcp://127.0.0.1:10201";
LOG4CXX_INFO(logger, "To connect to " << serverAddr);
AirEnv env(serverAddr, "Alien-v0", clientNum);
auto info = env.init();
auto actionSpace = std::get<1>(info);
auto obSpace = std::get<0>(info);
LOG4CXX_INFO(logger, "Action space: " << actionSpace.type << ", " << actionSpace.shape);
LOG4CXX_INFO(logger, "Observation space:" << obSpace.type << "-" << obSpace.shape);
bool isDone = false;
auto obsv = env.reset();
LOG4CXX_INFO(logger, "reset " << obsv.size());
while (!isDone) {
auto actions = std::vector<long>(clientNum, 3);
auto stepResult = env.step(actions);
obsv = std::get<0>(stepResult);
auto rewardVec = std::get<1>(stepResult);
auto doneVec = std::get<2>(stepResult);
// LOG4CXX_INFO(logger, "reward: " << rewardVec);
for (int i = 0; i < doneVec.size(); i ++) {
if (doneVec[i]) {
auto tmpObs = env.reset(i);
LOG4CXX_INFO(logger, "Reset client " << i << "result: " << tmpObs.size());
}
}
}
}
}
/**
 * Entry point.
 *
 * Usage: testairenv [serverAddr]
 * serverAddr is optional and defaults to the local server used by the other
 * tests. Reading argv[1] unconditionally was undefined behaviour when no
 * argument was given, because argv[argc] is a null pointer.
 */
int main(int argc, char** argv) {
	log4cxx::BasicConfigurator::configure();
	const std::string serverAddr = (argc > 1) ? argv[1] : "tcp://127.0.0.1:10201";
	testGetInfo(serverAddr);
	// testEpisode();
	// testReset();
	LOG4CXX_INFO(logger, "End of test");
	return 0;
}