From 9e8ac5c8ba1f498ac6fe1303d53761ca65be1514 Mon Sep 17 00:00:00 2001
From: zombie-einstein <13398815+zombie-einstein@users.noreply.github.com>
Date: Wed, 11 Dec 2024 00:08:49 +0000
Subject: [PATCH] feat: Scaled rewards and target velocities (#10)

* Use channels view parameters
* Rename parameters
* Include step-number in observation
* Add velocity field to targets
* Add time scaled reward function
---
 .../swarms/search_and_rescue/conftest.py      |  2 +-
 .../swarms/search_and_rescue/dynamics.py      | 19 ++++----
 .../swarms/search_and_rescue/env.py           | 43 +++++++++++--------
 .../swarms/search_and_rescue/env_test.py      | 16 ++++---
 .../swarms/search_and_rescue/generator.py     |  5 ++-
 .../swarms/search_and_rescue/observations.py  | 12 +++---
 .../search_and_rescue/observations_test.py    | 10 ++++-
 .../swarms/search_and_rescue/reward.py        | 24 +++++++++--
 .../swarms/search_and_rescue/reward_test.py   | 16 ++++++-
 .../swarms/search_and_rescue/types.py         |  6 ++-
 .../swarms/search_and_rescue/utils_test.py    | 13 ++++--
 11 files changed, 115 insertions(+), 51 deletions(-)

diff --git a/jumanji/environments/swarms/search_and_rescue/conftest.py b/jumanji/environments/swarms/search_and_rescue/conftest.py
index 70cb1b907..6b63645aa 100644
--- a/jumanji/environments/swarms/search_and_rescue/conftest.py
+++ b/jumanji/environments/swarms/search_and_rescue/conftest.py
@@ -28,7 +28,7 @@ def env() -> SearchAndRescue:
         searcher_min_speed=0.01,
         searcher_max_speed=0.05,
         searcher_view_angle=0.5,
-        max_steps=25,
+        time_limit=10,
     )


diff --git a/jumanji/environments/swarms/search_and_rescue/dynamics.py b/jumanji/environments/swarms/search_and_rescue/dynamics.py
index d9abe5b19..63353dcb7 100644
--- a/jumanji/environments/swarms/search_and_rescue/dynamics.py
+++ b/jumanji/environments/swarms/search_and_rescue/dynamics.py
@@ -17,18 +17,20 @@
 import chex
 import jax

+from jumanji.environments.swarms.search_and_rescue.types import TargetState
+

 class TargetDynamics(abc.ABC):
     @abc.abstractmethod
-    def __call__(self, key: chex.PRNGKey, target_pos: chex.Array) -> chex.Array:
+    def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> TargetState:
         """Interface for target position update function.

         Args:
             key: random key.
-            target_pos: Current target positions.
+            targets: Current target states.

         Returns:
-            Updated target positions.
+            Updated target states.
         """


@@ -46,16 +48,17 @@ def __init__(self, step_size: float):
         """
         self.step_size = step_size

-    def __call__(self, key: chex.PRNGKey, target_pos: chex.Array) -> chex.Array:
+    def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> TargetState:
         """Update target positions.

         Args:
             key: random key.
-            target_pos: Current target positions.
+            targets: Current target states.

         Returns:
-            Updated target positions.
+            Updated target states.
""" - d_pos = jax.random.uniform(key, target_pos.shape) + d_pos = jax.random.uniform(key, targets.pos.shape) d_pos = self.step_size * 2.0 * (d_pos - 0.5) - return target_pos + d_pos + pos = (targets.pos + d_pos) % env_size + return TargetState(pos=pos, vel=targets.vel, found=targets.found) diff --git a/jumanji/environments/swarms/search_and_rescue/env.py b/jumanji/environments/swarms/search_and_rescue/env.py index 267975854..609e8ba68 100644 --- a/jumanji/environments/swarms/search_and_rescue/env.py +++ b/jumanji/environments/swarms/search_and_rescue/env.py @@ -113,7 +113,7 @@ def __init__( searcher_min_speed: float = 0.01, searcher_max_speed: float = 0.02, searcher_view_angle: float = 0.75, - max_steps: int = 400, + time_limit: int = 400, viewer: Optional[Viewer[State]] = None, target_dynamics: Optional[TargetDynamics] = None, generator: Optional[Generator] = None, @@ -136,7 +136,7 @@ def __init__( The view cone of an agent goes from +- of the view angle relative to its heading, e.g. 0.5 would mean searchers have a 90° view angle in total. - max_steps: Maximum number of environment steps allowed for search. + time_limit: Maximum number of environment steps allowed for search. viewer: `Viewer` used for rendering. Defaults to `SearchAndRescueViewer`. target_dynamics: target_dynamics: Target object dynamics model, implemented as a @@ -156,7 +156,7 @@ def __init__( max_speed=searcher_max_speed, view_angle=searcher_view_angle, ) - self.max_steps = max_steps + self.time_limit = time_limit self._target_dynamics = target_dynamics or RandomWalk(0.001) self.generator = generator or RandomGenerator(num_targets=100, num_searchers=2) self._viewer = viewer or SearchAndRescueViewer() @@ -180,7 +180,7 @@ def __repr__(self) -> str: f" - target contact range: {self.target_contact_range}", f" - num vision: {self._observation.num_vision}", f" - agent radius: {self._observation.agent_radius}", - f" - max steps: {self.max_steps}," + f" - time limit: {self.time_limit}," f" - env size: {self.generator.env_size}" f" - target dynamics: {self._target_dynamics.__class__.__name__}", f" - generator: {self.generator.__class__.__name__}", @@ -223,11 +223,12 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser searchers = update_state( key, self.generator.env_size, self.searcher_params, state.searchers, actions ) - # Ensure target positions are wrapped - target_pos = self._target_dynamics(target_key, state.targets.pos) % self.generator.env_size + + targets = self._target_dynamics(target_key, state.targets, self.generator.env_size) + # Searchers return an array of flags of any targets they are in range of, # and that have not already been located, result shape here is (n-searcher, n-targets) - n_targets = target_pos.shape[0] + n_targets = targets.pos.shape[0] targets_found = spatial( utils.searcher_detect_targets, reduction=jnp.logical_or, @@ -238,14 +239,14 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser key, self.searcher_params.view_angle, searchers, - (jnp.arange(n_targets), state.targets), + (jnp.arange(n_targets), targets), pos=searchers.pos, - pos_b=target_pos, + pos_b=targets.pos, env_size=self.generator.env_size, n_targets=n_targets, ) - rewards = self._reward_fn(targets_found) + rewards = self._reward_fn(targets_found, state.step, self.time_limit) targets_found = jnp.any(targets_found, axis=0) # Targets need to remain found if they already have been @@ -253,14 +254,14 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, 
TimeStep[Obser state = State( searchers=searchers, - targets=TargetState(pos=target_pos, found=targets_found), + targets=TargetState(pos=targets.pos, vel=targets.vel, found=targets_found), key=key, step=state.step + 1, ) observation = self._state_to_observation(state) observation = jax.lax.stop_gradient(observation) timestep = jax.lax.cond( - jnp.logical_or(state.step >= self.max_steps, jnp.all(targets_found)), + jnp.logical_or(state.step >= self.time_limit, jnp.all(targets_found)), termination, transition, rewards, @@ -273,9 +274,13 @@ def _state_to_observation(self, state: State) -> Observation: return Observation( searcher_views=searcher_views, targets_remaining=1.0 - jnp.sum(state.targets.found) / self.generator.num_targets, - time_remaining=1.0 - state.step / (self.max_steps + 1), + step=state.step, ) + @cached_property + def num_agents(self) -> int: + return self.generator.num_searchers + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Returns the observation spec. @@ -287,7 +292,11 @@ def observation_spec(self) -> specs.Spec[Observation]: observation_spec: Search-and-rescue observation spec """ searcher_views = specs.BoundedArray( - shape=(self.generator.num_searchers, *self._observation.view_shape), + shape=( + self.generator.num_searchers, + self._observation.num_channels, + self._observation.num_vision, + ), minimum=-1.0, maximum=1.0, dtype=float, @@ -298,10 +307,10 @@ def observation_spec(self) -> specs.Spec[Observation]: "ObservationSpec", searcher_views=searcher_views, targets_remaining=specs.BoundedArray( - shape=(), minimum=0.0, maximum=1.0, name="targets_remaining", dtype=float + shape=(), minimum=0.0, maximum=1.0, name="targets_remaining", dtype=jnp.float32 ), - time_remaining=specs.BoundedArray( - shape=(), minimum=0.0, maximum=1.0, name="time_remaining", dtype=float + step=specs.BoundedArray( + shape=(), minimum=0, maximum=self.time_limit, name="step", dtype=jnp.int32 ), ) diff --git a/jumanji/environments/swarms/search_and_rescue/env_test.py b/jumanji/environments/swarms/search_and_rescue/env_test.py index 998045c9b..4f0b051d6 100644 --- a/jumanji/environments/swarms/search_and_rescue/env_test.py +++ b/jumanji/environments/swarms/search_and_rescue/env_test.py @@ -57,7 +57,8 @@ def test_env_init(env: SearchAndRescue, key: chex.PRNGKey) -> None: assert isinstance(timestep.observation, Observation) assert timestep.observation.searcher_views.shape == ( env.generator.num_searchers, - *env._observation.view_shape, + env._observation.num_channels, + env._observation.num_vision, ) assert timestep.step_type == StepType.FIRST @@ -69,8 +70,9 @@ def test_env_step(env: SearchAndRescue, key: chex.PRNGKey, env_size: float) -> N check states (i.e. positions, heading, speeds) all fall inside expected ranges. 
""" - n_steps = 22 + n_steps = env.time_limit env.generator.env_size = env_size + env.time_limit = 22 def step( carry: Tuple[chex.PRNGKey, State], _: None @@ -108,7 +110,7 @@ def step( def test_env_does_not_smoke(env: SearchAndRescue) -> None: """Test that we can run an episode without any errors.""" - env.max_steps = 10 + env.time_limit = 10 def select_action(action_key: chex.PRNGKey, _state: Observation) -> chex.Array: return jax.random.uniform( @@ -132,7 +134,9 @@ def test_target_detection(env: SearchAndRescue, key: chex.PRNGKey) -> None: searchers=AgentState( pos=jnp.array([[0.5, 0.5]]), heading=jnp.array([jnp.pi]), speed=jnp.array([0.0]) ), - targets=TargetState(pos=jnp.array([[0.54, 0.5]]), found=jnp.array([False])), + targets=TargetState( + pos=jnp.array([[0.54, 0.5]]), vel=jnp.zeros((1, 2)), found=jnp.array([False]) + ), key=key, ) state, timestep = env.step(state, jnp.zeros((1, 2))) @@ -188,7 +192,9 @@ def test_multi_target_detection(env: SearchAndRescue, key: chex.PRNGKey) -> None pos=jnp.array([[0.5, 0.5]]), heading=jnp.array([0.5 * jnp.pi]), speed=jnp.array([0.0]) ), targets=TargetState( - pos=jnp.array([[0.54, 0.5], [0.46, 0.5]]), found=jnp.array([False, False]) + pos=jnp.array([[0.54, 0.5], [0.46, 0.5]]), + vel=jnp.zeros((2, 2)), + found=jnp.array([False, False]), ), key=key, ) diff --git a/jumanji/environments/swarms/search_and_rescue/generator.py b/jumanji/environments/swarms/search_and_rescue/generator.py index 3a3c85251..e0d627db5 100644 --- a/jumanji/environments/swarms/search_and_rescue/generator.py +++ b/jumanji/environments/swarms/search_and_rescue/generator.py @@ -83,11 +83,14 @@ def __call__(self, key: chex.PRNGKey, searcher_params: AgentParams) -> State: target_pos = jax.random.uniform( target_key, (self.num_targets, 2), minval=0.0, maxval=self.env_size ) + target_vel = jnp.zeros((self.num_targets, 2)) state = State( searchers=searcher_state, targets=TargetState( - pos=target_pos, found=jnp.full((self.num_targets,), False, dtype=bool) + pos=target_pos, + vel=target_vel, + found=jnp.full((self.num_targets,), False, dtype=bool), ), key=key, ) diff --git a/jumanji/environments/swarms/search_and_rescue/observations.py b/jumanji/environments/swarms/search_and_rescue/observations.py index f3f917e59..9779c2b33 100644 --- a/jumanji/environments/swarms/search_and_rescue/observations.py +++ b/jumanji/environments/swarms/search_and_rescue/observations.py @@ -27,7 +27,7 @@ class ObservationFn(abc.ABC): def __init__( self, - view_shape: Tuple[int, ...], + num_channels: int, num_vision: int, vision_range: float, view_angle: float, @@ -38,14 +38,14 @@ def __init__( Base class for observation function mapping state to individual agent views. Args: - view_shape: Individual agent view shape. + num_channels: Number of channels in agent view. num_vision: Size of vision array. vision_range: Vision range. view_angle: Agent view angle (as a fraction of pi). agent_radius: Agent/target visual radius. env_size: Environment size. """ - self.view_shape = view_shape + self.num_channels = num_channels self.num_vision = num_vision self.vision_range = vision_range self.view_angle = view_angle @@ -85,7 +85,7 @@ def __init__( env_size: Environment size. 
""" super().__init__( - (1, num_vision), + 1, num_vision, vision_range, view_angle, @@ -199,7 +199,7 @@ def __init__( self.agent_radius = agent_radius self.env_size = env_size super().__init__( - (2, num_vision), + 2, num_vision, vision_range, view_angle, @@ -333,7 +333,7 @@ def __init__( self.agent_radius = agent_radius self.env_size = env_size super().__init__( - (3, num_vision), + 3, num_vision, vision_range, view_angle, diff --git a/jumanji/environments/swarms/search_and_rescue/observations_test.py b/jumanji/environments/swarms/search_and_rescue/observations_test.py index 3b02b3e42..17b092fa4 100644 --- a/jumanji/environments/swarms/search_and_rescue/observations_test.py +++ b/jumanji/environments/swarms/search_and_rescue/observations_test.py @@ -84,7 +84,9 @@ def test_searcher_view( searchers=AgentState( pos=searcher_positions, heading=searcher_headings, speed=searcher_speed ), - targets=TargetState(pos=jnp.zeros((1, 2)), found=jnp.zeros((1, 2), dtype=bool)), + targets=TargetState( + pos=jnp.zeros((1, 2)), vel=jnp.zeros((1, 2)), found=jnp.zeros((1, 2), dtype=bool) + ), key=key, ) @@ -164,7 +166,9 @@ def test_search_and_target_view_searchers( searchers=AgentState( pos=searcher_positions, heading=searcher_headings, speed=searcher_speed ), - targets=TargetState(pos=jnp.zeros((1, 2)), found=jnp.zeros((1,), dtype=bool)), + targets=TargetState( + pos=jnp.zeros((1, 2)), vel=jnp.zeros((1, 2)), found=jnp.zeros((1,), dtype=bool) + ), key=key, ) @@ -241,6 +245,7 @@ def test_search_and_target_view_targets( searchers=AgentState(pos=searcher_position, heading=searcher_heading, speed=searcher_speed), targets=TargetState( pos=target_position, + vel=jnp.zeros_like(target_position), found=target_found, ), key=key, @@ -328,6 +333,7 @@ def test_search_and_all_target_view_targets( searchers=AgentState(pos=searcher_position, heading=searcher_heading, speed=searcher_speed), targets=TargetState( pos=target_position, + vel=jnp.zeros_like(target_position), found=target_found, ), key=key, diff --git a/jumanji/environments/swarms/search_and_rescue/reward.py b/jumanji/environments/swarms/search_and_rescue/reward.py index 1217b86f7..720adc3fa 100644 --- a/jumanji/environments/swarms/search_and_rescue/reward.py +++ b/jumanji/environments/swarms/search_and_rescue/reward.py @@ -22,7 +22,7 @@ class RewardFn(abc.ABC): """Abstract class for `SearchAndRescue` rewards.""" @abc.abstractmethod - def __call__(self, found_targets: chex.Array) -> chex.Array: + def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array: """The reward function used in the `SearchAndRescue` environment. Args: @@ -41,7 +41,7 @@ class SharedRewardFn(RewardFn): can receive rewards for detecting multiple targets. """ - def __call__(self, found_targets: chex.Array) -> chex.Array: + def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array: rewards = found_targets.astype(float) norms = jnp.sum(rewards, axis=0)[jnp.newaxis] rewards = jnp.where(norms > 0, rewards / norms, rewards) @@ -57,7 +57,25 @@ class IndividualRewardFn(RewardFn): even if a target is detected by multiple agents. """ - def __call__(self, found_targets: chex.Array) -> chex.Array: + def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array: rewards = found_targets.astype(float) rewards = jnp.sum(rewards, axis=1) return rewards + + +class SharedScaledRewardFn(RewardFn): + """ + Calculate per agent rewards from detected targets + + Targets detected by multiple agents share rewards. 
Agents + can receive rewards for detecting multiple targets. + Rewards are scaled by the current time step. + """ + + def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array: + rewards = found_targets.astype(float) + norms = jnp.sum(rewards, axis=0)[jnp.newaxis] + rewards = jnp.where(norms > 0, rewards / norms, rewards) + rewards = jnp.sum(rewards, axis=1) + scale = (time_limit - step) / time_limit + return scale * rewards diff --git a/jumanji/environments/swarms/search_and_rescue/reward_test.py b/jumanji/environments/swarms/search_and_rescue/reward_test.py index 303b48b5c..e43590871 100644 --- a/jumanji/environments/swarms/search_and_rescue/reward_test.py +++ b/jumanji/environments/swarms/search_and_rescue/reward_test.py @@ -20,14 +20,26 @@ def test_rewards_from_found_targets() -> None: targets_found = jnp.array([[False, True, True], [False, False, True]], dtype=bool) - shared_rewards = reward.SharedRewardFn()(targets_found) + shared_rewards = reward.SharedRewardFn()(targets_found, 0, 10) assert shared_rewards.shape == (2,) assert shared_rewards.dtype == jnp.float32 assert jnp.allclose(shared_rewards, jnp.array([1.5, 0.5])) - individual_rewards = reward.IndividualRewardFn()(targets_found) + individual_rewards = reward.IndividualRewardFn()(targets_found, 0, 10) assert individual_rewards.shape == (2,) assert individual_rewards.dtype == jnp.float32 assert jnp.allclose(individual_rewards, jnp.array([2.0, 1.0])) + + shared_scaled_rewards = reward.SharedScaledRewardFn()(targets_found, 0, 10) + + assert shared_scaled_rewards.shape == (2,) + assert shared_scaled_rewards.dtype == jnp.float32 + assert jnp.allclose(shared_scaled_rewards, jnp.array([1.5, 0.5])) + + shared_scaled_rewards = reward.SharedScaledRewardFn()(targets_found, 10, 10) + + assert shared_scaled_rewards.shape == (2,) + assert shared_scaled_rewards.dtype == jnp.float32 + assert jnp.allclose(shared_scaled_rewards, jnp.array([0.0, 0.0])) diff --git a/jumanji/environments/swarms/search_and_rescue/types.py b/jumanji/environments/swarms/search_and_rescue/types.py index cb924bd60..28c9600bb 100644 --- a/jumanji/environments/swarms/search_and_rescue/types.py +++ b/jumanji/environments/swarms/search_and_rescue/types.py @@ -29,11 +29,13 @@ class TargetState: The state for the rescue targets. pos: 2d position of the target agents + velocity: 2d velocity of the target agents found: Boolean flag indicating if the target has been located by a searcher. 
""" pos: chex.Array # (num_targets, 2) + vel: chex.Array # (num_targets, 2) found: chex.Array # (num_targets,) @@ -75,5 +77,5 @@ class Observation(NamedTuple): """ searcher_views: chex.Array # (num_searchers, num_vision) - targets_remaining: chex.Array # () - time_remaining: chex.Array # () + targets_remaining: chex.Numeric # () + step: chex.Numeric # () diff --git a/jumanji/environments/swarms/search_and_rescue/utils_test.py b/jumanji/environments/swarms/search_and_rescue/utils_test.py index 0f7328c43..018e895a7 100644 --- a/jumanji/environments/swarms/search_and_rescue/utils_test.py +++ b/jumanji/environments/swarms/search_and_rescue/utils_test.py @@ -30,14 +30,18 @@ def test_random_walk_dynamics(key: chex.PRNGKey) -> None: n_targets = 50 - s0 = jnp.full((n_targets, 2), 0.5) + pos_0 = jnp.full((n_targets, 2), 0.5) + + s0 = TargetState( + pos=pos_0, vel=jnp.zeros((n_targets, 2)), found=jnp.zeros((n_targets,), dtype=bool) + ) dynamics = RandomWalk(0.1) assert isinstance(dynamics, TargetDynamics) - s1 = dynamics(key, s0) + s1 = dynamics(key, s0, 1.0) - assert s1.shape == (n_targets, 2) - assert jnp.all(jnp.abs(s0 - s1) < 0.1) + assert s1.pos.shape == (n_targets, 2) + assert jnp.all(jnp.abs(s0.pos - s1.pos) < 0.1) @pytest.mark.parametrize( @@ -66,6 +70,7 @@ def test_target_found( ) -> None: target = TargetState( pos=jnp.zeros((2,)), + vel=jnp.zeros((2,)), found=target_state, )