Skip to content

Commit

Permalink
Refactor with help from CodeRabbitAI
Browse files Browse the repository at this point in the history
  • Loading branch information
JuliusMiller committed Sep 25, 2024
1 parent 5a1dda1 commit fa52769
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 11 deletions.
20 changes: 14 additions & 6 deletions robot_sf/gym_env/env_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
env_util
"""
from typing import List, Union
from enum import Enum

from gymnasium import spaces
import numpy as np
Expand All @@ -14,6 +15,10 @@
from robot_sf.sensor.sensor_fusion import fused_sensor_space, SensorFusion
from robot_sf.sim.simulator import Simulator, PedSimulator

class AgentType(Enum):
ROBOT = 1
PEDESTRIAN = 2

def init_collision_and_sensors(
sim: Simulator,
env_config: EnvSettings,
Expand Down Expand Up @@ -91,18 +96,21 @@ def init_spaces(env_config: EnvSettings, map_def: MapDefinition):
A tuple containing the action space, the extended observation space, and
the original observation space of the robot.
"""
action_space, obs_space, orig_obs_space = create_spaces(env_config, map_def, create_robot=True)
action_space, obs_space, orig_obs_space = create_spaces(env_config, map_def, agent_type=AgentType.ROBOT)
# Return the action space, the extended observation space, and the original
# observation space
return action_space, obs_space, orig_obs_space

def create_spaces(env_config: Union[EnvSettings, PedEnvSettings], map_def: MapDefinition,
create_robot: bool = True):
agent_type: AgentType = AgentType.ROBOT):
# Create a agent using the factory method in the environment configuration
if create_robot:
if agent_type == AgentType.ROBOT:
agent = env_config.robot_factory()
else:
elif agent_type == AgentType.PEDESTRIAN:
agent = env_config.pedestrian_factory()
else:
raise ValueError(f"Unsupported agent type: {agent_type}")


# Get the action space from the agent
action_space = agent.action_space
Expand Down Expand Up @@ -142,8 +150,8 @@ def init_ped_spaces(env_config: PedEnvSettings, map_def: MapDefinition):
A tuple containing a list of action space, the extended observation space, and
the original observation space of the robot and the pedestrian.
"""
action_space_robot, obs_space_robot, orig_obs_space_robot = create_spaces(env_config, map_def, create_robot=True)
action_space_ped, obs_space_ped, orig_obs_space_ped = create_spaces(env_config, map_def, create_robot=False)
action_space_robot, obs_space_robot, orig_obs_space_robot = create_spaces(env_config, map_def, agent_type=AgentType.ROBOT)
action_space_ped, obs_space_ped, orig_obs_space_ped = create_spaces(env_config, map_def, agent_type=AgentType.PEDESTRIAN)

# As a list [robot, pedestrian]
return [action_space_robot, action_space_ped], [obs_space_robot, obs_space_ped], [orig_obs_space_robot, orig_obs_space_ped]
Expand Down
4 changes: 2 additions & 2 deletions robot_sf/gym_env/pedestrian_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,8 @@ def step(self, action):
# if recording is enabled, record the state
if self.recording_enabled:
self.record()

return obs_ped, reward, term, False,{"step": meta["step"], "meta": meta}
truncated = False
return obs_ped, reward, term, truncated, {"step": meta["step"], "meta": meta}

def reset(self, seed=None, options=None):
"""
Expand Down
4 changes: 2 additions & 2 deletions robot_sf/render/sim_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,8 +367,8 @@ def _augment_lidar(self, ray_vecs: np.ndarray):
def _augment_action(self, action: VisualizableAction, color):
r_x, r_y = action.pose[0]
# scale vector length to be always visible
vec_length = action.robot_action[0] * self.scaling
vec_orient = action.robot_pose[1]
vec_length = action.action[0] * self.scaling
vec_orient = action.pose[1]

def from_polar(length: float, orient: float) -> Vec2D:
return cos(orient) * length, sin(orient) * length
Expand Down
2 changes: 1 addition & 1 deletion robot_sf/sim/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def get_proximity_point(self, fixed_point: Tuple[float, float],
if not self.is_obstacle_collision(new_x, new_y):
return new_x, new_y

logger.warning("Could not find a valid proximity point: {fixed_point}.")
logger.warning(f"Could not find a valid proximity point: {fixed_point}.")
spawn_id = sample(self.map_def.ped_spawn_zones, k=1)[0] # Spawn in pedestrian spawn_zone
initial_spawn = sample_zone(spawn_id, 1)[0]
return initial_spawn
Expand Down

0 comments on commit fa52769

Please sign in to comment.