test_agent.py
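"""Offline sanity check for a trained agent checkpoint.

Loads the checkpoint (with its agent_config.yaml / exp_config.yaml), replays
held-out transitions from a robobuf replay buffer, and compares predicted
action chunks against the logged actions using an L2 error and a per-dimension
sign-mismatch count.

Usage:
    python test_agent.py /path/to/run_dir/model.pt --buffer_path /path/to/buf.pkl
"""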
import argparse
import os
import pickle as pkl
import random
from pathlib import Path

import cv2
import hydra
import numpy as np
import torch
import yaml
from robobuf import ReplayBuffer as RB

from data4robotics.transforms import get_transform_by_name

# constants for data loading
BUF_SHUFFLE_RNG = 3904767649  # fixed shuffle seed, from replay_buffer.py
n_test_trans = 500  # usually hardcoded in task/franka.yaml
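

# Thin inference wrapper around a trained agent: restores the hydra agent
# config and checkpoint weights, reads the training camera index from the
# experiment config, and applies the repo's 'preproc' image transform
# before inference.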
class BaselinePolicy:
    def __init__(self, agent_path, model_name):
        # restore the agent's hydra config and the experiment config
        with open(Path(agent_path, "agent_config.yaml"), "r") as f:
            agent_config = yaml.safe_load(f)
        with open(Path(agent_path, "exp_config.yaml"), "r") as f:
            exp_config = yaml.safe_load(f)
        self.cam_idx = exp_config['params']['task']['train_buffer']['cam_idx']

        # instantiate the agent and load the checkpoint weights
        agent = hydra.utils.instantiate(agent_config)
        save_dict = torch.load(Path(agent_path, model_name), map_location="cpu")
        agent.load_state_dict(save_dict['model'])
        self.agent = agent.eval().cuda()
        self.transform = get_transform_by_name('preproc')

    def _proc_image(self, rgb_img, size=(256, 256)):
        # resize, convert HWC uint8 -> CHW float in [0, 1], then apply the
        # preproc transform and add a batch dimension
        rgb_img = cv2.resize(rgb_img, size, interpolation=cv2.INTER_AREA)
        rgb_img = torch.from_numpy(rgb_img).float().permute((2, 0, 1)) / 255
        return self.transform(rgb_img)[None].cuda()

    def forward(self, img, obs):
        # predict an action chunk for a single (image, proprio state) pair
        img = self._proc_image(img)
        state = torch.from_numpy(obs)[None].float().cuda()
        with torch.no_grad():
            ac = self.agent.get_actions(img, state)
        ac = ac[0].cpu().numpy().astype(np.float32)
        return ac

    @property
    def ac_chunk(self):
        return self.agent.ac_chunk
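

# Minimal usage sketch (hypothetical paths and variable names; assumes the
# checkpoint directory contains agent_config.yaml and exp_config.yaml, as
# __init__ above requires):
#
#   policy = BaselinePolicy("/path/to/run_dir", "model_final.pt")
#   action_chunk = policy.forward(rgb_image, proprio_state)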


def _get_data(idx, buf, ac_chunk, cam_idx):
    t = buf[idx]

    # walk forward through the trajectory, collecting the next ac_chunk actions
    loop_t, chunked_actions = t, []
    for _ in range(ac_chunk):
        if loop_t.next is None:
            break
        chunked_actions.append(loop_t.action[None])
        loop_t = loop_t.next
    if len(chunked_actions) != ac_chunk:
        raise ValueError(f"transition {idx} has fewer than {ac_chunk} successors")

    i_t, o_t = t.obs.image(cam_idx), t.obs.state
    a_t = np.concatenate(chunked_actions, 0)
    return i_t, o_t, a_t
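

# Re-create the fixed-seed shuffle (presumably matching the split in
# replay_buffer.py) so the first n_test_trans shuffled indices correspond to
# held-out transitions, then score the policy on a 50-transition subset.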
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoint")
    parser.add_argument("--buffer_path", default='/scratch/sudeep/toaster3/buf.pkl')
    args = parser.parse_args()

    agent_path = os.path.expanduser(os.path.dirname(args.checkpoint))
    model_name = os.path.basename(args.checkpoint)
    policy = BaselinePolicy(agent_path, model_name)

    # build data loader
    cam_idx = policy.cam_idx
    print('cam_idx:', cam_idx)
    with open(args.buffer_path, 'rb') as f:
        buf = RB.load_traj_list(pkl.load(f))

    # shuffle the list of buf indices with the fixed seed, and take the
    # first n_test_trans entries as test data
    rng = random.Random(BUF_SHUFFLE_RNG)
    index_list = list(range(len(buf)))
    rng.shuffle(index_list)
    index_list = index_list[:n_test_trans]

    l2s, lsigs = [], []
    for idx in index_list[:50]:
        i_t, o_t, a_t = _get_data(idx, buf, policy.ac_chunk, cam_idx)
        pred_ac = policy.forward(i_t, o_t)

        # L2 distance between the chunks, plus a count of the action
        # dimensions where predicted and ground-truth signs disagree
        l2 = np.linalg.norm(a_t - pred_ac)
        lsign = np.sum(np.logical_or(np.logical_and(a_t > 0, pred_ac <= 0),
                                     np.logical_and(a_t <= 0, pred_ac > 0)))
        l2s.append(l2)
        lsigs.append(lsign)

        print('\n')
        print('a_t', a_t)
        print('pred_ac', pred_ac)
        print(f'losses: l2={l2:0.2f}\tlsign={lsign}')

    print('\n')
    print(f'avg losses: l2={np.mean(l2s):0.3f}\tlsign={np.mean(lsigs):0.3f}')


if __name__ == "__main__":
    main()