From 9bf416867d71e1c49672d20d26139224a3d3ea3d Mon Sep 17 00:00:00 2001
From: JeanElsner
Date: Fri, 17 Nov 2023 08:16:17 +0100
Subject: [PATCH] Add reward and action plotters

---
 src/dm_robotics/panda/utils.py | 105 +++++++++++++++++++++++++++++----
 1 file changed, 92 insertions(+), 13 deletions(-)

diff --git a/src/dm_robotics/panda/utils.py b/src/dm_robotics/panda/utils.py
index 7ca9f5e..6d44afe 100644
--- a/src/dm_robotics/panda/utils.py
+++ b/src/dm_robotics/panda/utils.py
@@ -151,15 +151,12 @@ def _time(self, physics: mjcf.Physics) -> np.ndarray:
     return np.array([physics.data.time])
 
 
-class PlotComponent(renderer.Component):
-  """ A plotting component for `dm_control.viewer.application.Application`. """
+class Plot(renderer.Component):
 
   def __init__(self,
                runtime: runtime_module.Runtime,
                maxlen: int = 500) -> None:
     self._rt = runtime
-    self._obs_idx = None
-    self._obs_keys = None
     self.maxlen = min(maxlen, mujoco.mjMAXLINEPNT)
     self.maxlines = 0
     self.x = np.linspace(-self.maxlen, 0, self.maxlen)
@@ -169,6 +166,27 @@ def __init__(self,
     self.fig.flg_barplot = 0
     self.fig.flg_selection = 0
     self.fig.range = [[1, 0], [1, 0]]
+    self.fig.linewidth = 1.5
+
+  def reset_data(self):
+    for i in range(self.maxlines):
+      for j in range(self.maxlen):
+        del j
+        self.y[i].append(0)
+
+
+class ObservationPlot(Plot):
+  """
+  Plotting component for :py:class:`dm_control.viewer.application.Application`
+  that allows you to browse through the observations.
+  """
+
+  def __init__(self,
+               runtime: runtime_module.Runtime,
+               maxlen: int = 1000) -> None:
+    super().__init__(runtime, maxlen)
+    self._obs_idx = None
+    self._obs_keys = None
 
   def _init_buffer(self):
     for obs in self._rt._time_step.observation.values():
@@ -180,13 +198,16 @@ def _init_buffer(self):
     self.reset_data()
     self._obs_idx = 0
     self._obs_keys = list(self._rt._time_step.observation.keys())
+    self.update_title()
+
+  def update_title(self):
+    self.fig.title = f'{self._obs_keys[self._obs_idx]:100s}'
 
   def render(self, context, viewport):
     if self._rt._time_step is None:
       return
     if self._obs_idx is None:
       self._init_buffer()
-    pos = mujoco.MjrRect(5, viewport.height - 256 - 5, 256, 256)
     obs = np.atleast_1d(
         self._rt._time_step.observation[self._obs_keys[self._obs_idx]])
     for i in range(self.maxlines):
@@ -197,22 +218,78 @@ def render(self, context, viewport):
         ]).T.reshape((-1,))
       else:
         self.fig.linepnt[i] = 0
-    self.fig.title = f'{self._obs_keys[self._obs_idx]:100s}'
+    pos = mujoco.MjrRect(5, viewport.height - 200 - 5, 300, 200)
     mujoco.mjr_figure(pos, self.fig, context.ptr)
 
-  def reset_data(self):
-    for i in range(self.maxlines):
-      for j in range(self.maxlen):
-        del j
-        self.y[i].append(0)
-
   def next_obs(self):
     self._obs_idx = (self._obs_idx + 1) % len(self._obs_keys)
     self.reset_data()
+    self.update_title()
 
   def prev_obs(self):
     self._obs_idx = (self._obs_idx - 1) % len(self._obs_keys)
     self.reset_data()
+    self.update_title()
+
+
+class ActionPlot(Plot):
+  """
+  A plotting component for :py:class:`dm_control.viewer.application.Application`
+  that plots the agent's actions.
+  """
+
+  def __init__(self,
+               runtime: runtime_module.Runtime,
+               maxlen: int = 1000) -> None:
+    super().__init__(runtime, maxlen)
+    self._init_buffer()
+    self.fig.title = 'Actions'
+
+  def _init_buffer(self):
+    self.maxlines = self._rt._default_action.shape[0]
+    for _1 in range(self.maxlines):
+      self.y.append(deque(maxlen=self.maxlen))
+    self.reset_data()
+
+  def render(self, context, viewport):
+    if self._rt._time_step is None:
+      return
+    for i, a in enumerate(self._rt.last_action):
+      self.fig.linepnt[i] = self.maxlen
+      self.y[i].append(a)
+      self.fig.linedata[i][:self.maxlen * 2] = np.array([self.x,
+                                                         self.y[i]]).T.reshape(
+                                                             (-1,))
+    pos = mujoco.MjrRect(300 + 5, viewport.height - 200 - 5, 300, 200)
+    mujoco.mjr_figure(pos, self.fig, context.ptr)
+
+
+class RewardPlot(Plot):
+  """
+  A plotting component for :py:class:`dm_control.viewer.application.Application`
+  that plots the environment's reward.
+  """
+
+  def __init__(self,
+               runtime: runtime_module.Runtime,
+               maxlen: int = 1000) -> None:
+    super().__init__(runtime, maxlen)
+    self.fig.title = 'Reward'
+    self.maxlines = 1
+    self.y.append(deque(maxlen=self.maxlen))
+    self.reset_data()
+
+  def render(self, context, viewport):
+    if self._rt._time_step is None:
+      return
+    r = self._rt._time_step.reward
+    self.fig.linepnt[0] = self.maxlen
+    self.y[0].append(r)
+    self.fig.linedata[0][:self.maxlen * 2] = np.array([self.x,
+                                                       self.y[0]]).T.reshape(
+                                                           (-1,))
+    pos = mujoco.MjrRect(2 * 300 + 5, viewport.height - 200 - 5, 300, 200)
+    mujoco.mjr_figure(pos, self.fig, context.ptr)
 
 
 class PlotHelp(views.ColumnTextModel):
@@ -235,7 +312,9 @@ def __init__(self, title='Explorer', width=1024, height=768):
 
   def _perform_deferred_reload(self, params):
     super()._perform_deferred_reload(params)
-    cmp = PlotComponent(self._runtime)
+    cmp = ObservationPlot(self._runtime)
     self._renderer.components += cmp
+    self._renderer.components += ActionPlot(self._runtime)
+    self._renderer.components += RewardPlot(self._runtime)
     self._input_map.bind(cmp.next_obs, user_input.KEY_F4)
     self._input_map.bind(cmp.prev_obs, user_input.KEY_F3)