-
Notifications
You must be signed in to change notification settings - Fork 8
/
12_wrappers_tasks.py
131 lines (115 loc) · 4.23 KB
/
12_wrappers_tasks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import logging
import gymnasium
import numpy as np
from mats_gym.envs import renderers
from mats_gym.tasks.tasks import TaskCombination
from mats_gym.tasks.traffic_event_tasks import (
InfractionAvoidanceTask,
RouteFollowingTask,
)
import cv2
from srunner.scenariomanager.traffic_events import TrafficEventType
from mats_gym.wrappers.birdseye_view.birdseye import BirdViewProducer
from mats_gym.wrappers.birdview import BirdViewObservationWrapper, ObservationConfig
from mats_gym.wrappers.meta_actions_wrapper import MetaActionWrapper
from mats_gym.wrappers.task_wrapper import TaskWrapper
import mats_gym
"""
This example shows how to use the TaskWrapper class to conveniently define new tasks.
A task defines an agent's reward and its termination condition.
Tasks can be combined using the TaskCombination class.
"""
# Number of episodes main() runs before closing the environment.
NUM_EPISODES = 3
def policy():
    """
    Sample a random driving action.

    Returns:
        np.ndarray of shape (3,) laid out as [throttle, steer, brake]:
        throttle is drawn uniformly from [0.5, 1.0), steer uniformly from
        [-0.5, 0.5), and brake is always 0.0.
    """
    # Throttle is drawn before steer; keep this order so seeded runs are reproducible.
    throttle = 0.5 + np.random.rand() / 2
    steer = np.random.rand() - 0.5
    brake = 0.0
    return np.array([throttle, steer, brake])
def show_obs(obs, agent, path="img.png"):
    """
    Write the birdview observation of ``agent`` to an image file.

    The birdview layers are collapsed into a single RGB image via
    ``BirdViewProducer.as_rgb`` and saved to ``path``. The file is
    overwritten on every call, so it always holds the latest frame.

    Args:
        obs: Per-agent observation dict; ``obs[agent]`` must contain a
            ``"birdview"`` entry.
        agent: Name of the agent whose observation is written.
        path: Output image path. Defaults to ``"img.png"``.
    """
    rgb = BirdViewProducer.as_rgb(obs[agent]["birdview"])
    # cv2.imwrite interprets 3-channel arrays as BGR; convert first so the
    # colours on disk are not swapped.
    cv2.imwrite(path, cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
    # Short GUI event pump / delay. It only has an effect if an OpenCV window
    # exists somewhere (NOTE(review): none is created in this function).
    cv2.waitKey(10)
def _build_tasks(agents):
    """Return a dict mapping each agent name to its combined task.

    Each agent gets a ``TaskCombination`` of a ``RouteFollowingTask`` and an
    ``InfractionAvoidanceTask`` (over the infractions listed below), combined
    with weights ``[0.01, 1.0]``.
    """
    infractions = [
        TrafficEventType.COLLISION_VEHICLE.name,
        TrafficEventType.COLLISION_STATIC.name,
        TrafficEventType.COLLISION_PEDESTRIAN.name,
        TrafficEventType.ON_SIDEWALK_INFRACTION.name,
        TrafficEventType.ROUTE_DEVIATION.name,
        TrafficEventType.OUTSIDE_LANE_INFRACTION.name,
        TrafficEventType.WRONG_WAY_INFRACTION.name,
        TrafficEventType.TRAFFIC_LIGHT_INFRACTION.name,
        TrafficEventType.STOP_INFRACTION.name,
    ]
    return {
        agent: TaskCombination(
            agent=agent,
            tasks=[
                RouteFollowingTask(agent=agent),
                # Each task gets its own copy of the list, as in a per-agent literal.
                InfractionAvoidanceTask(agent=agent, infractions=list(infractions)),
            ],
            weights=[0.01, 1.0],
        )
        for agent in agents
    }


def _run_episode(env):
    """Run one episode with the random ``policy``.

    Returns:
        (rewards, info): cumulative reward per agent, and the ``info`` dict
        from the final ``env.step``.
    """
    obs, info = env.reset()
    rewards = {agent: 0.0 for agent in env.agents}
    episode_over = False
    while not episode_over:
        actions = {agent: policy() for agent in env.agents}
        obs, step_rewards, terminated, truncated, info = env.step(actions)
        show_obs(obs, "vehicle_0")
        for agent, step_reward in step_rewards.items():
            rewards[agent] = rewards.get(agent, 0.0) + step_reward
        print(
            f"Cum. Rewards: {', '.join([f'{agent}={reward}' for agent, reward in rewards.items()])}"
        )
        # An agent is finished when it is terminated OR truncated (Gymnasium
        # step semantics). Checking only `terminated` would loop forever when
        # an episode ends through truncation.
        episode_over = all(
            terminated[agent] or truncated.get(agent, False) for agent in terminated
        )
        env.render()
    return rewards, info


def _print_episode_summary(rewards, info):
    """Print each agent's cumulative reward and the traffic events in ``info``."""
    # Iterate the accumulated rewards rather than env.agents: finished agents
    # may be dropped from env.agents at episode end (NOTE(review): parallel
    # multi-agent convention; confirm for mats_gym).
    for agent, total in rewards.items():
        print(f"Agent {agent}: reward={total}, events:")
        for event in info.get(agent, {}).get("events", []):
            text = f" - {event['event']} at frame {event.get('frame', 'N/A')}"
            if (
                event["event"] == TrafficEventType.ROUTE_COMPLETION.name
                and "route_completed" in event
            ):
                text += f" completion={event['route_completed']:.2f}"
            print(text)


def main():
    """Run NUM_EPISODES episodes of the four-way route scenario with task-based rewards.

    Wraps the Scenic env with a birdview observation wrapper and a
    ``TaskWrapper`` whose tasks replace the wrapped env's reward (termination
    from the wrapped env is kept). The env is always closed, even on error.
    """
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s - %(filename)s - [%(levelname)s] - %(message)s",
    )
    env = mats_gym.scenic_env(
        host="localhost",
        port=2000,
        scenario_specification="scenarios/scenic/four_way_route_scenario.scenic",
        scenes_per_scenario=2,
        resample_scenes=False,
        agent_name_prefixes=["vehicle"],
        render_mode="human",
        render_config=renderers.camera_pov(agent="vehicle_0"),
    )
    try:
        env = BirdViewObservationWrapper(env=env)
        env = TaskWrapper(
            env=env,
            tasks=_build_tasks(env.agents),
            ignore_wrapped_env_reward=True,
            ignore_wrapped_env_termination=False,
        )
        for _ in range(NUM_EPISODES):
            rewards, info = _run_episode(env)
            _print_episode_summary(rewards, info)
    finally:
        # Release the simulator connection even if an episode raises.
        env.close()


if __name__ == "__main__":
    main()