-
Notifications
You must be signed in to change notification settings - Fork 3
/
evaluate.py
64 lines (58 loc) · 2.37 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import tyro
from nudge.evaluator import Evaluator
from nudge.evaluator_neuralppo import EvaluatorNeuralPPO
from dataclasses import dataclass
import tyro
# @dataclass
# class Args:
# env_name: str = "kangaroo"
# """name of the environment"""
# # agent_path: str = "out/runs/kangaroo_softmax_lr_0.00025_llr_0.00025_blr_0.00025_gamma_0.99_bentcoef_0.0_numenvs_60_steps_128_pretrained_False_joint_True_20"
# agent_path: str = "out/runs/{}_best".format(env_name)
# """path for the agent to be loaded"""
# fps: int = 5
# """frames per second"""
def main(
    env_name: str = "seaquest",
    agent_path: str = "out/runs/kangaroo_softmax_lr_0.00025_llr_0.00025_blr_0.00025_gamma_0.99_bentcoef_0.0_numenvs_60_steps_128_pretrained_False_joint_True_20",
    fps: int = 5,
    episodes: int = 2,
    model: str = 'blendrl',
    device: str = 'cuda:0'
) -> None:
    """Build the evaluator selected by ``model`` and run it.

    Args:
        env_name: Name of the environment, forwarded to the evaluator.
        agent_path: Path of the agent to be loaded, forwarded to the evaluator.
        fps: Frames per second, forwarded to the evaluator.
        episodes: Episode count, forwarded to the evaluator.
        model: Which evaluator to use: ``'blendrl'`` selects ``Evaluator``,
            ``'neuralppo'`` selects ``EvaluatorNeuralPPO``.
        device: Device string forwarded to the evaluator.

    Raises:
        ValueError: If ``model`` is not one of the supported names.
    """
    # NOTE(review): the default env_name is "seaquest" while the default
    # agent_path points at a "kangaroo" run -- confirm this pairing is intended.

    # Validate with an explicit raise rather than `assert`: asserts are removed
    # under `python -O`, which would let an unknown model fall through and fail
    # later with an unrelated UnboundLocalError.
    if model not in ('blendrl', 'neuralppo'):
        raise ValueError("Invalid model type; choose from ['blendrl', 'neuralppo']")

    # Both evaluators are constructed with the same keyword arguments, so pick
    # the class first and build it once.
    evaluator_cls = Evaluator if model == 'blendrl' else EvaluatorNeuralPPO
    evaluator = evaluator_cls(
        episodes=episodes,
        agent_path=agent_path,
        env_name=env_name,
        fps=fps,
        deterministic=False,
        device=device,
        # Alternative setting: env_kwargs=dict(render_oc_overlay=True)
        env_kwargs=dict(render_oc_overlay=False),
        render_predicate_probs=True,
    )
    evaluator.run()
if __name__ == "__main__":
    # Hand main() to tyro so its parameters are filled in from the command line.
    tyro.cli(main)
# args = tyro.cli(Args)
# renderer = Renderer(\
# agent_path="out/runs/kangaroo_softmax_lr_0.00025_llr_0.00025_blr_0.00025_gamma_0.99_bentcoef_0.0_numenvs_60_steps_128_pretrained_False_joint_True_20",
# # agent_path="out/runs/kangaroo_softmax_lr_0.00025_llr_0.00025_blr_0.00025_gamma_0.99_bentcoef_0.01_numenvs_60_steps_128_pretrained_False_joint_True_50",
# env_name="kangaroo",
# fps=5,
# deterministic=False,
# env_kwargs=dict(render_oc_overlay=True),
# render_predicate_probs=True)
# renderer.run()