Skip to content

Commit

Permalink
Add reward and action plotters
Browse files Browse the repository at this point in the history
  • Loading branch information
JeanElsner committed Nov 17, 2023
1 parent 8d5a955 commit 9bf4168
Showing 1 changed file with 92 additions and 13 deletions.
105 changes: 92 additions & 13 deletions src/dm_robotics/panda/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,12 @@ def _time(self, physics: mjcf.Physics) -> np.ndarray:
return np.array([physics.data.time])


class PlotComponent(renderer.Component):
""" A plotting component for `dm_control.viewer.application.Application`. """
class Plot(renderer.Component):

def __init__(self,
runtime: runtime_module.Runtime,
maxlen: int = 500) -> None:
self._rt = runtime
self._obs_idx = None
self._obs_keys = None
self.maxlen = min(maxlen, mujoco.mjMAXLINEPNT)
self.maxlines = 0
self.x = np.linspace(-self.maxlen, 0, self.maxlen)
Expand All @@ -169,6 +166,27 @@ def __init__(self,
self.fig.flg_barplot = 0
self.fig.flg_selection = 0
self.fig.range = [[1, 0], [1, 0]]
self.fig.linewidth = 1.5

def reset_data(self):
    """Overwrite the history of every plotted line with zeros.

    Pushes ``self.maxlen`` zeros onto each of the first ``self.maxlines``
    buffers in ``self.y``.
    """
    # NOTE(review): the buffers are presumably deques bounded to
    # ``self.maxlen``, in which case this flushes all previous samples --
    # confirm against where ``self.y`` is populated.
    for i in range(self.maxlines):
        self.y[i].extend([0] * self.maxlen)


class ObservationPlot(Plot):
"""
Plotting component for :py:class:`dm_control.viewer.application.Application`
that allows you to browse through the observations.
"""

def __init__(self,
             runtime: runtime_module.Runtime,
             maxlen: int = 1000) -> None:
    """Set up the base plot and mark the observation selection as unset.

    Args:
      runtime: Runtime object forwarded to the base class constructor.
      maxlen: History length forwarded to the base class constructor.
    """
    super().__init__(runtime, maxlen)
    # Both stay ``None`` until the buffers are initialised elsewhere.
    self._obs_keys = None
    self._obs_idx = None

def _init_buffer(self):
for obs in self._rt._time_step.observation.values():
Expand All @@ -180,13 +198,16 @@ def _init_buffer(self):
self.reset_data()
self._obs_idx = 0
self._obs_keys = list(self._rt._time_step.observation.keys())
self.update_title()

def update_title(self):
    """Set the figure title to the currently selected observation key.

    The key is formatted with the ``100s`` spec, i.e. padded to a width of
    at least 100 characters.
    """
    current_key = self._obs_keys[self._obs_idx]
    self.fig.title = format(current_key, '100s')

def render(self, context, viewport):
if self._rt._time_step is None:
return
if self._obs_idx is None:
self._init_buffer()
pos = mujoco.MjrRect(5, viewport.height - 256 - 5, 256, 256)
obs = np.atleast_1d(
self._rt._time_step.observation[self._obs_keys[self._obs_idx]])
for i in range(self.maxlines):
Expand All @@ -197,22 +218,78 @@ def render(self, context, viewport):
]).T.reshape((-1,))
else:
self.fig.linepnt[i] = 0
self.fig.title = f'{self._obs_keys[self._obs_idx]:100s}'
pos = mujoco.MjrRect(5, viewport.height - 200 - 5, 300, 200)
mujoco.mjr_figure(pos, self.fig, context.ptr)

def reset_data(self):
    """Overwrite the history of every plotted line with zeros.

    Pushes ``self.maxlen`` zeros onto each of the first ``self.maxlines``
    buffers in ``self.y``.
    """
    # NOTE(review): the buffers are presumably deques bounded to
    # ``self.maxlen``, in which case this flushes all previous samples --
    # confirm against where ``self.y`` is populated.
    for i in range(self.maxlines):
        self.y[i].extend([0] * self.maxlen)

def next_obs(self):
    """Select the following observation key, wrapping around at the end.

    After moving the index, the line history is reset and the figure title
    is refreshed.
    """
    n_keys = len(self._obs_keys)
    self._obs_idx = (self._obs_idx + 1) % n_keys
    self.reset_data()
    self.update_title()

def prev_obs(self):
    """Select the preceding observation key, wrapping around at the start.

    After moving the index, the line history is reset and the figure title
    is refreshed.
    """
    n_keys = len(self._obs_keys)
    # Python's modulo keeps the result non-negative, so index 0 wraps to
    # the last key.
    self._obs_idx = (self._obs_idx - 1) % n_keys
    self.reset_data()
    self.update_title()


class ActionPlot(Plot):
    """
    A plotting component for :py:class:`dm_control.viewer.application.Application`
    that plots the agent's actions.

    One line is kept per entry of the runtime's default action; each render
    call appends the runtime's last action to those lines.
    """

    def __init__(self,
                 runtime: runtime_module.Runtime,
                 maxlen: int = 1000) -> None:
        """Create the plot and allocate its line buffers.

        Args:
          runtime: Runtime whose last action is plotted.
          maxlen: History length forwarded to the base class constructor.
        """
        super().__init__(runtime, maxlen)
        self._init_buffer()
        self.fig.title = 'Actions'

    def _init_buffer(self):
        """Allocate one bounded history buffer per action entry and reset them."""
        # The number of lines is the leading dimension of the default action.
        self.maxlines = self._rt._default_action.shape[0]
        for _ in range(self.maxlines):
            self.y.append(deque(maxlen=self.maxlen))
        self.reset_data()

    def render(self, context, viewport):
        """Append the latest action to the history and draw the figure.

        Nothing is drawn while the runtime has no time step or no action.
        """
        if self._rt._time_step is None:
            return
        action = self._rt.last_action
        # NOTE(review): ``last_action`` appears to remain ``None`` until the
        # runtime has applied a first action (e.g. viewer paused right after
        # a reset) -- iterating ``None`` would raise ``TypeError``.
        if action is None:
            return
        for i, a in enumerate(action):
            self.y[i].append(a)
            self.fig.linepnt[i] = self.maxlen
            # Interleave the samples into the flat [x0, y0, x1, y1, ...]
            # layout written into ``linedata``.
            self.fig.linedata[i][:self.maxlen * 2] = np.array(
                [self.x, self.y[i]]).T.reshape((-1,))
        pos = mujoco.MjrRect(300 + 5, viewport.height - 200 - 5, 300, 200)
        mujoco.mjr_figure(pos, self.fig, context.ptr)


class RewardPlot(Plot):
    """
    A plotting component for :py:class:`dm_control.viewer.application.Application`
    that plots the environment's reward.

    A single line holds the reward history; each render call appends the
    reward of the runtime's current time step.
    """

    def __init__(self,
                 runtime: runtime_module.Runtime,
                 maxlen: int = 1000) -> None:
        """Create the plot with a single bounded history buffer.

        Args:
          runtime: Runtime whose current time step provides the reward.
          maxlen: History length forwarded to the base class constructor.
        """
        super().__init__(runtime, maxlen)
        self.fig.title = 'Reward'
        self.maxlines = 1
        self.y.append(deque(maxlen=self.maxlen))
        self.reset_data()

    def render(self, context, viewport):
        """Append the current reward to the history and draw the figure.

        Nothing is drawn while the runtime has no time step.
        """
        if self._rt._time_step is None:
            return
        r = self._rt._time_step.reward
        # dm_env reports ``reward=None`` on the first step of an episode;
        # ``None`` cannot be written into the numeric line data, so plot it
        # as zero instead.
        if r is None:
            r = 0.
        self.y[0].append(r)
        self.fig.linepnt[0] = self.maxlen
        # Interleave the samples into the flat [x0, y0, x1, y1, ...] layout
        # written into ``linedata``.
        self.fig.linedata[0][:self.maxlen * 2] = np.array(
            [self.x, self.y[0]]).T.reshape((-1,))
        pos = mujoco.MjrRect(2 * 300 + 5, viewport.height - 200 - 5, 300, 200)
        mujoco.mjr_figure(pos, self.fig, context.ptr)


class PlotHelp(views.ColumnTextModel):
Expand All @@ -235,7 +312,9 @@ def __init__(self, title='Explorer', width=1024, height=768):

def _perform_deferred_reload(self, params):
super()._perform_deferred_reload(params)
cmp = PlotComponent(self._runtime)
cmp = ObservationPlot(self._runtime)
self._renderer.components += cmp
self._renderer.components += ActionPlot(self._runtime)
self._renderer.components += RewardPlot(self._runtime)
self._input_map.bind(cmp.next_obs, user_input.KEY_F4)
self._input_map.bind(cmp.prev_obs, user_input.KEY_F3)

0 comments on commit 9bf4168

Please sign in to comment.