add multiprocessing to multi-robot env (wip)
Bonifatius94 committed Aug 15, 2023
1 parent 4d8146b commit 1be8d59
Showing 1 changed file with 27 additions and 18 deletions.
45 changes: 27 additions & 18 deletions robot_sf/robot_env.py
@@ -1,9 +1,10 @@
from __future__ import annotations
from math import ceil
from typing import Tuple, Callable, List, Protocol, Any
from dataclasses import dataclass, field
from copy import deepcopy
from multiprocessing.pool import ThreadPool

import numpy as np
from gym.vector import VectorEnv
from gym import Env, spaces
from robot_sf.nav.map_config import MapDefinition
@@ -216,7 +217,7 @@ def __init__(
self.last_action = None
if debug:
self.sim_ui = SimulationView(
scaling=10,
scaling=6,
obstacles=map_def.obstacles,
robot_radius=env_config.robot_config.radius,
ped_radius=env_config.sim_config.ped_radius,
@@ -258,7 +259,7 @@ def exit(self):

class MultiRobotEnv(VectorEnv):
"""Representing an OpenAI Gym environment for training
a self-driving robot with reinforcement learning"""
multiple self-driving robots with reinforcement learning"""

def __init__(
self, env_config: EnvSettings = EnvSettings(),
@@ -268,6 +269,10 @@ def __init__(
map_def = env_config.map_pool.map_defs[0] # info: only use first map
action_space, observation_space, orig_obs_space = init_spaces(env_config, map_def)
super(MultiRobotEnv, self).__init__(num_robots, observation_space, action_space)
self.action_space = spaces.Box(
low=np.array([self.single_action_space.low for _ in range(num_robots)]),
high=np.array([self.single_action_space.high for _ in range(num_robots)]),
dtype=self.single_action_space.low.dtype)

self.reward_func, self.debug = reward_func, debug
self.simulators = init_simulators(env_config, map_def, num_robots)
@@ -281,6 +286,9 @@ def __init__(
for nav, occ, sen in zip(sim.robot_navs, occupancies, sensors)]
self.states.extend(states)

self.sim_worker_pool = ThreadPool(len(self.simulators))
self.obs_worker_pool = ThreadPool(num_robots)

def step(self, actions):
actions = [self.simulators[0].robots[0].parse_action(a) for a in actions]
i = 0
@@ -290,15 +298,15 @@ def step(self, actions):
actions_per_simulator.append(actions[i:i+num_robots])
i += num_robots

# TODO: parallelize
for sim, sim_actions in zip(self.simulators, actions_per_simulator):
sim.step_once(sim_actions)
self.sim_worker_pool.map(
lambda s_a: s_a[0].step_once(s_a[1]),
zip(self.simulators, actions_per_simulator))

# TODO: parallelize
obs = [state.step() for state in self.states]
obs = self.obs_worker_pool.map(lambda s: s.step(), self.states)

metas = [state.meta_dict() for state in self.states]
masked_metas = [{ "step": meta["step"], "meta": meta } for meta in metas]
masked_metas = (*masked_metas,)
terms = [state.is_terminal for state in self.states]
rewards = [self.reward_func(meta) for meta in metas]

@@ -307,22 +315,23 @@ def step(self, actions):
sim.reset_state()
obs[i] = state.reset()

obs = { OBS_DRIVE_STATE: [o[OBS_DRIVE_STATE] for o in obs],
OBS_RAYS: [o[OBS_RAYS] for o in obs] }
obs = { OBS_DRIVE_STATE: np.array([o[OBS_DRIVE_STATE] for o in obs]),
OBS_RAYS: np.array([o[OBS_RAYS] for o in obs])}

return obs, rewards, terms, masked_metas

def reset(self):
# TODO: parallelize
for sim in self.simulators:
sim.reset_state()
obs = [state.reset() for state in self.states]
self.sim_worker_pool.map(lambda sim: sim.reset_state(), self.simulators)
obs = self.obs_worker_pool.map(lambda s: s.reset(), self.states)

return {
OBS_DRIVE_STATE: [o[OBS_DRIVE_STATE] for o in obs],
OBS_RAYS: [o[OBS_RAYS] for o in obs],
}
obs = { OBS_DRIVE_STATE: np.array([o[OBS_DRIVE_STATE] for o in obs]),
OBS_RAYS: np.array([o[OBS_RAYS] for o in obs]) }
return obs

def render(self, robot_id: int=0):
# TODO: add support for PyGame rendering
pass

def close_extras(self, **kwargs):
self.sim_worker_pool.close()
self.obs_worker_pool.close()
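
For readers skimming the diff, here is a minimal, self-contained sketch of the ThreadPool pattern this commit introduces for stepping several simulators concurrently. FakeSimulator and step_all are illustrative stand-ins, not part of robot_sf; only the pool.map-over-zip idiom mirrors the change above.

# Minimal sketch of the worker-pool stepping pattern (illustrative, not repository code).
from multiprocessing.pool import ThreadPool
from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class FakeSimulator:
    # Stand-in for a robot_sf simulator; a real one would integrate robot/pedestrian dynamics.
    sim_id: int
    tick: int = 0

    def step_once(self, actions: List[Tuple[float, float]]):
        self.tick += 1

def step_all(simulators: List[FakeSimulator],
             actions_per_simulator: List[List[Tuple[float, float]]]):
    # One worker thread per simulator; map blocks until every simulator has finished its tick.
    with ThreadPool(len(simulators)) as pool:
        pool.map(lambda s_a: s_a[0].step_once(s_a[1]),
                 zip(simulators, actions_per_simulator))

sims = [FakeSimulator(i) for i in range(4)]
step_all(sims, [[(1.0, 0.0)] for _ in sims])
print([s.tick for s in sims])  # -> [1, 1, 1, 1]

Because ThreadPool spawns threads rather than processes, the lambdas passed to map need not be picklable; on the other hand, the GIL limits any speedup to step functions that spend most of their time in GIL-releasing code such as NumPy.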

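The __init__ hunk above also overrides the action_space set by the VectorEnv base constructor with a gym Box whose bounds are the single-robot bounds stacked once per robot. A hedged sketch of that construction, with made-up per-robot bounds (robot_sf's real single-robot action space is defined elsewhere in the file):

# Illustrative only: the (-1, 1) bounds and 2-D action are assumptions, not robot_sf's actual limits.
import numpy as np
from gym import spaces

num_robots = 3
single_action_space = spaces.Box(low=np.array([-1.0, -1.0]),
                                 high=np.array([1.0, 1.0]),
                                 dtype=np.float64)

# Stack the single-robot bounds once per robot -> a (num_robots, act_dim) Box.
batched_action_space = spaces.Box(
    low=np.array([single_action_space.low for _ in range(num_robots)]),
    high=np.array([single_action_space.high for _ in range(num_robots)]),
    dtype=single_action_space.low.dtype)

print(batched_action_space.shape)  # -> (3, 2)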