-
Notifications
You must be signed in to change notification settings - Fork 1
/
test.py
43 lines (37 loc) · 1.46 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import gym
import numpy as np
from cartpole_controller import LocalLinearizationController
from env.cartpole_control_env import CartPoleControlEnv
# Initial states used as the evaluation cases. Each entry is a length-4 float
# array in which only the third component varies (0.0 through 1.4 in steps of
# 0.2); the remaining three components are always zero.
# NOTE(review): the third component is presumably the pole angle -- confirm
# against the environment's state layout.
init_states = [
    np.array([0.0, 0.0, third_component, 0.0])
    for third_component in (0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4)
]
def test(init_state, x_s, u_s, T=500, num_episodes=100):
    """Return the average accumulated cost of the local policy over noisy rollouts.

    A ``LocalLinearizationController`` is built on the noise-free
    ``CartPoleControlEnv-v0`` environment and asked for its local policy via
    ``compute_local_policy(x_s, u_s, T)``. The resulting policy is then rolled
    out ``num_episodes`` times on ``NoisyCartPoleControlEnv-v0``, each episode
    starting from ``init_state``.

    Args:
        init_state: State passed to ``env.reset(state=...)`` at the start of
            every episode.
        x_s: State argument forwarded to ``compute_local_policy``
            (presumably the point the controller linearizes around -- confirm
            against the controller).
        u_s: Control argument forwarded to ``compute_local_policy``.
        T: Third argument forwarded to ``compute_local_policy`` (presumably
            the number of time steps / policy entries -- confirm against the
            controller).
        num_episodes: Number of rollouts to average over.

    Returns:
        The sum of the per-step costs over all episodes divided by
        ``num_episodes``.

    Raises:
        ZeroDivisionError: If ``num_episodes`` is 0, since no cost is
            accumulated and the final division is by zero.
    """
    # Plan on the noise-free environment. It is only needed while the policy
    # is being computed, so release it as soon as that is done instead of
    # leaking it when the evaluation environment is created below.
    planning_env = gym.make("env:CartPoleControlEnv-v0")
    try:
        controller = LocalLinearizationController(planning_env)
        policies = controller.compute_local_policy(x_s, u_s, T)
    finally:
        planning_env.close()

    # For testing, we use a noisy environment which adds small Gaussian noise
    # to the state transition. The controller only needs to consider the
    # environment without noise.
    env = gym.make("env:NoisyCartPoleControlEnv-v0")
    total_cost = 0
    try:
        for _ in range(num_episodes):
            observation = env.reset(state=init_state)
            # Each policy entry is a (K, k) pair applied as an affine
            # feedback law: action = K @ observation + k.
            for K, k in policies:
                action = K @ observation + k
                observation, cost, done, _info = env.step(action)
                total_cost += cost
                if done:
                    # The environment ended the episode early; skip the
                    # remaining policy entries for this rollout.
                    break
    finally:
        # Always release the evaluation environment, even if a rollout fails.
        env.close()
    return total_cost / num_episodes
def main():
    """Evaluate the controller from every state in ``init_states``.

    For each initial state, calls ``test`` with the default ``T`` and
    ``num_episodes`` and prints the case index together with the returned
    average cost.
    """
    # Both reference arrays are all zeros.
    # NOTE(review): presumably the state/control pair the controller is meant
    # to stabilize around -- confirm against the controller.
    x_s = np.array([0, 0, 0, 0])
    u_s = np.array([0])
    for i, s in enumerate(init_states):
        # Fixed the misspelled "avergae" in the printed label.
        print("case {} average cost:".format(i), test(s, x_s, u_s))


# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()