Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement predator-prey flock environment #258

Prev Previous commit
Next Next commit
feat: Implement pred-prey environment viewer
  • Loading branch information
zombie-einstein committed Nov 3, 2024
commit 9431b26fff80a4f8fa32750abcb160934713ca49
20 changes: 16 additions & 4 deletions jumanji/environments/swarms/predator_prey/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from functools import cached_property
from typing import Any, Tuple
from typing import Optional, Tuple

import chex
import jax
Expand All @@ -25,6 +25,7 @@
from jumanji.environments.swarms.common.types import AgentParams
from jumanji.environments.swarms.common.updates import init_state, update_state, view
from jumanji.types import TimeStep, restart, termination, transition
from jumanji.viewer import Viewer

from .types import Actions, Observation, Rewards, State
from .updates import (
Expand All @@ -33,6 +34,7 @@
sparse_predator_rewards,
sparse_prey_rewards,
)
from .viewer import PredatorPreyViewer


class PredatorPrey(Environment):
Expand Down Expand Up @@ -115,8 +117,10 @@ class PredatorPrey(Environment):
)
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)
env.render(state)
action = env.action_spec.generate_value()
state, timestep = jax.jit(env.step)(state, action)
env.render(state)
```
"""

Expand All @@ -142,6 +146,7 @@ def __init__(
prey_max_speed: float,
prey_view_angle: float,
max_steps: int = 10_000,
viewer: Optional[Viewer[State]] = None,
) -> None:
"""
Instantiates a `PredatorPrey` environment
Expand Down Expand Up @@ -198,6 +203,7 @@ def __init__(
The view cone of an agent spans +/- the view angle
relative to its heading.
max_steps: Maximum number of environment steps before termination
viewer: `Viewer` used for rendering. Defaults to `PredatorPreyViewer`.
"""
self.num_predators = num_predators
self.num_prey = num_prey
Expand All @@ -224,6 +230,7 @@ def __init__(
)
self.max_steps = max_steps
super().__init__()
self._viewer = viewer or PredatorPreyViewer()

def __repr__(self) -> str:
return "\n".join(
Expand Down Expand Up @@ -531,9 +538,14 @@ def reward_spec(self) -> specs.Spec[Rewards]: # type: ignore[override]
prey=prey,
)

def render(self, state: State) -> Any:
"""Not currently implemented for this environment"""
raise NotImplementedError("Render method not implemented for this environment.")
def render(self, state: State) -> None:
    """Draw the given environment state using the configured viewer.

    Args:
        state: State object to be handed to the viewer for drawing.
    """
    # Rendering is fully delegated to the viewer held by the environment.
    viewer = self._viewer
    viewer.render(state)

def close(self) -> None:
    """Perform any necessary cleanup.

    Closes the viewer held by the environment so that any resources it
    opened for rendering (e.g. figure windows) are released.
    """
    self._viewer.close()
191 changes: 191 additions & 0 deletions jumanji/environments/swarms/predator_prey/viewer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Sequence, Tuple

import jax.numpy as jnp
import matplotlib.animation
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.quiver import Quiver

import jumanji
import jumanji.environments
from jumanji.environments.swarms.common.types import AgentState
from jumanji.environments.swarms.predator_prey.types import State
from jumanji.viewer import Viewer


class PredatorPreyViewer(Viewer):
    """Matplotlib viewer for the `PredatorPrey` environment.

    Agents are drawn as a quiver plot: one arrow per agent, placed at the
    agent position and pointing along the agent heading. Predators and prey
    are drawn in separate colors.
    """

    def __init__(
        self,
        figure_name: str = "PredatorPrey",
        figure_size: Tuple[float, float] = (6.0, 6.0),
        predator_color: str = "red",
        prey_color: str = "green",
    ) -> None:
        """Viewer for the `PredatorPrey` environment.

        Args:
            figure_name: the window name to be used when initialising the window.
            figure_size: tuple (width, height) of the matplotlib figure window;
                passed directly as the matplotlib `figsize`.
            predator_color: matplotlib color used to draw predator agents.
            prey_color: matplotlib color used to draw prey agents.
        """
        self._figure_name = figure_name
        self._figure_size = figure_size
        self.predator_color = predator_color
        self.prey_color = prey_color
        self._animation: Optional[matplotlib.animation.Animation] = None

    def render(self, state: State) -> None:
        """
        Render frames of the environment for a given state using matplotlib.

        Args:
            state: State object containing the current dynamics of the environment.

        """
        self._clear_display()
        fig, ax = self._get_fig_ax()
        self._draw(ax, state)
        self._update_display(fig)

    def animate(
        self,
        states: Sequence[State],
        interval: int = 200,
        save_path: Optional[str] = None,
    ) -> matplotlib.animation.FuncAnimation:
        """
        Create an animation from a sequence of states.

        Args:
            states: sequence of `State` corresponding to subsequent timesteps.
            interval: delay between frames in milliseconds, default to 200.
            save_path: the path where the animation file should be saved. If it is None, the plot
                will not be saved.

        Returns:
            Animation object that can be saved as a GIF, MP4, or rendered with HTML.

        Raises:
            ValueError: if `states` is empty.
        """
        if not states:
            raise ValueError(f"The states argument has to be non-empty, got {states}.")
        fig, ax = plt.subplots(
            num=f"{self._figure_name}Anim", figsize=self._figure_size
        )
        fig, ax = self._format_plot(fig, ax)

        # Draw the first state once, then only move/rotate the existing arrows
        # on each frame rather than re-creating the artists.
        predators_quiver = self._draw_agents(
            ax, states[0].predators, self.predator_color
        )
        prey_quiver = self._draw_agents(ax, states[0].prey, self.prey_color)

        def make_frame(state: State) -> Tuple[Quiver, Quiver]:
            predators_quiver.set_offsets(state.predators.pos)
            predators_quiver.set_UVC(
                jnp.cos(state.predators.heading), jnp.sin(state.predators.heading)
            )
            prey_quiver.set_offsets(state.prey.pos)
            prey_quiver.set_UVC(
                jnp.cos(state.prey.heading), jnp.sin(state.prey.heading)
            )
            # Return a flat iterable of the modified artists, which is the
            # shape `FuncAnimation` expects from its frame function.
            return predators_quiver, prey_quiver

        # Create the animation object.
        matplotlib.rc("animation", html="jshtml")
        self._animation = matplotlib.animation.FuncAnimation(
            fig,
            make_frame,
            frames=states,
            interval=interval,
            blit=False,
        )

        # Save the animation to file if a path was provided.
        if save_path:
            self._animation.save(save_path)

        return self._animation

    def close(self) -> None:
        """Perform any necessary cleanup.

        Environments will automatically :meth:`close()` themselves when
        garbage collected or when the program exits.
        """
        plt.close(self._figure_name)

    def _draw(self, ax: Axes, state: State) -> None:
        """Clear the axes and draw all agents of `state` onto them."""
        ax.clear()
        # Clearing the axes also resets its ticks and limits, so the axes
        # formatting has to be re-applied before drawing each frame.
        self._format_axes(ax)
        self._draw_agents(ax, state.predators, self.predator_color)
        self._draw_agents(ax, state.prey, self.prey_color)

    def _draw_agents(self, ax: Axes, agent_states: AgentState, color: str) -> Quiver:
        """Draw one arrow per agent and return the resulting quiver artist.

        Arrows are placed at the first two columns of `agent_states.pos` and
        point in the direction (cos(heading), sin(heading)).
        """
        q = ax.quiver(
            agent_states.pos[:, 0],
            agent_states.pos[:, 1],
            jnp.cos(agent_states.heading),
            jnp.sin(agent_states.heading),
            color=color,
            pivot="middle",
        )
        return q

    def _get_fig_ax(self) -> Tuple[Figure, Axes]:
        """Return the viewer figure and its axes, creating them on first use."""
        exists = plt.fignum_exists(self._figure_name)
        if exists:
            fig = plt.figure(self._figure_name)
            ax = fig.get_axes()[0]
        else:
            fig = plt.figure(self._figure_name, figsize=self._figure_size)
            fig.set_tight_layout({"pad": False, "w_pad": 0.0, "h_pad": 0.0})
            if not plt.isinteractive():
                fig.show()
            ax = fig.add_subplot()

        fig, ax = self._format_plot(fig, ax)
        return fig, ax

    def _format_plot(self, fig: Figure, ax: Axes) -> Tuple[Figure, Axes]:
        """Shrink the figure margins to a thin border and format the axes."""
        border = 0.01
        fig.subplots_adjust(
            top=1.0 - border,
            bottom=border,
            right=1.0 - border,
            left=border,
            hspace=0,
            wspace=0,
        )
        self._format_axes(ax)

        return fig, ax

    def _format_axes(self, ax: Axes) -> None:
        """Hide the axis ticks and fix both axis limits to [0, 1]."""
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

    def _update_display(self, fig: Figure) -> None:
        """Redraw the figure canvas for the current matplotlib mode."""
        if plt.isinteractive():
            # Required to update render when using Jupyter Notebook.
            fig.canvas.draw()
            if jumanji.environments.is_colab():
                plt.show(self._figure_name)
        else:
            # Required to update render when not using Jupyter Notebook.
            fig.canvas.draw_idle()
            fig.canvas.flush_events()

    def _clear_display(self) -> None:
        """Clear the previous cell output when running in Colab."""
        if jumanji.environments.is_colab():
            # Imported lazily: IPython is only needed when running in Colab.
            import IPython.display

            IPython.display.clear_output(True)