diff --git a/robot_sf/gym_env/env_util.py b/robot_sf/gym_env/env_util.py index c06033d..ecde2ac 100644 --- a/robot_sf/gym_env/env_util.py +++ b/robot_sf/gym_env/env_util.py @@ -2,6 +2,7 @@ env_util """ from typing import List, Union +from enum import Enum from gymnasium import spaces import numpy as np @@ -14,6 +15,10 @@ from robot_sf.sensor.sensor_fusion import fused_sensor_space, SensorFusion from robot_sf.sim.simulator import Simulator, PedSimulator +class AgentType(Enum): + ROBOT = 1 + PEDESTRIAN = 2 + def init_collision_and_sensors( sim: Simulator, env_config: EnvSettings, @@ -91,18 +96,21 @@ def init_spaces(env_config: EnvSettings, map_def: MapDefinition): A tuple containing the action space, the extended observation space, and the original observation space of the robot. """ - action_space, obs_space, orig_obs_space = create_spaces(env_config, map_def, create_robot=True) + action_space, obs_space, orig_obs_space = create_spaces(env_config, map_def, agent_type=AgentType.ROBOT) # Return the action space, the extended observation space, and the original # observation space return action_space, obs_space, orig_obs_space def create_spaces(env_config: Union[EnvSettings, PedEnvSettings], map_def: MapDefinition, - create_robot: bool = True): + agent_type: AgentType = AgentType.ROBOT): # Create a agent using the factory method in the environment configuration - if create_robot: + if agent_type == AgentType.ROBOT: agent = env_config.robot_factory() - else: + elif agent_type == AgentType.PEDESTRIAN: agent = env_config.pedestrian_factory() + else: + raise ValueError(f"Unsupported agent type: {agent_type}") + # Get the action space from the agent action_space = agent.action_space @@ -142,8 +150,8 @@ def init_ped_spaces(env_config: PedEnvSettings, map_def: MapDefinition): A tuple containing a list of action space, the extended observation space, and the original observation space of the robot and the pedestrian. """ - action_space_robot, obs_space_robot, orig_obs_space_robot = create_spaces(env_config, map_def, create_robot=True) - action_space_ped, obs_space_ped, orig_obs_space_ped = create_spaces(env_config, map_def, create_robot=False) + action_space_robot, obs_space_robot, orig_obs_space_robot = create_spaces(env_config, map_def, agent_type=AgentType.ROBOT) + action_space_ped, obs_space_ped, orig_obs_space_ped = create_spaces(env_config, map_def, agent_type=AgentType.PEDESTRIAN) # As a list [robot, pedestrian] return [action_space_robot, action_space_ped], [obs_space_robot, obs_space_ped], [orig_obs_space_robot, orig_obs_space_ped] diff --git a/robot_sf/gym_env/pedestrian_env.py b/robot_sf/gym_env/pedestrian_env.py index 8a70607..4a581d2 100644 --- a/robot_sf/gym_env/pedestrian_env.py +++ b/robot_sf/gym_env/pedestrian_env.py @@ -195,8 +195,8 @@ def step(self, action): # if recording is enabled, record the state if self.recording_enabled: self.record() - - return obs_ped, reward, term, False,{"step": meta["step"], "meta": meta} + truncated = False + return obs_ped, reward, term, truncated, {"step": meta["step"], "meta": meta} def reset(self, seed=None, options=None): """ diff --git a/robot_sf/render/sim_view.py b/robot_sf/render/sim_view.py index 433c124..424dde9 100644 --- a/robot_sf/render/sim_view.py +++ b/robot_sf/render/sim_view.py @@ -367,8 +367,8 @@ def _augment_lidar(self, ray_vecs: np.ndarray): def _augment_action(self, action: VisualizableAction, color): r_x, r_y = action.pose[0] # scale vector length to be always visible - vec_length = action.robot_action[0] * self.scaling - vec_orient = action.robot_pose[1] + vec_length = action.action[0] * self.scaling + vec_orient = action.pose[1] def from_polar(length: float, orient: float) -> Vec2D: return cos(orient) * length, sin(orient) * length diff --git a/robot_sf/sim/simulator.py b/robot_sf/sim/simulator.py index 183190c..0d4a68d 100644 --- a/robot_sf/sim/simulator.py +++ b/robot_sf/sim/simulator.py @@ -239,7 +239,7 @@ def get_proximity_point(self, fixed_point: Tuple[float, float], if not self.is_obstacle_collision(new_x, new_y): return new_x, new_y - logger.warning("Could not find a valid proximity point: {fixed_point}.") + logger.warning(f"Could not find a valid proximity point: {fixed_point}.") spawn_id = sample(self.map_def.ped_spawn_zones, k=1)[0] # Spawn in pedestrian spawn_zone initial_spawn = sample_zone(spawn_id, 1)[0] return initial_spawn