Skip to content

Commit

Permalink
feat: extend reward plot for multiple signals (#11)
Browse files Browse the repository at this point in the history
* feat: extend reward plot for multiple signals

* fix: differentiate between scalar and array reward in plot utility
  • Loading branch information
JeanElsner committed Jun 21, 2024
1 parent 1d51de4 commit 33ebfb8
Showing 1 changed file with 25 additions and 9 deletions.
34 changes: 25 additions & 9 deletions src/dm_robotics/panda/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,19 +256,35 @@ def __init__(self,
maxlen: int = 500) -> None:
super().__init__(runtime, maxlen)
self.fig.title = 'Reward'
self.maxlines = 1
self.y.append(deque(maxlen=self.maxlen))
self.maxlines = None

def _init_buffer(self):
if isinstance(self._rt._time_step.reward, np.ndarray):
self.maxlines = self._rt._time_step.reward.shape[0]
else:
self.maxlines = 1
for _1 in range(self.maxlines):
self.y.append(deque(maxlen=self.maxlen))
self.reset_data()

def render(self, context, viewport):
if self._rt._time_step is None:
if self._rt._time_step is None or self._rt._time_step.reward is None:
return
r = self._rt._time_step.reward
self.fig.linepnt[0] = self.maxlen
self.y[0].append(r)
self.fig.linedata[0][:self.maxlen * 2] = np.array([self.x,
self.y[0]]).T.reshape(
(-1,))
if self.maxlines is None:
self._init_buffer()
if self.maxlines > 1:
for i, r in enumerate(self._rt._time_step.reward):
self.fig.linepnt[i] = self.maxlen
self.y[i].append(r)
self.fig.linedata[i][:self.maxlen * 2] = np.array([self.x, self.y[i]
]).T.reshape((-1,))
else:
r = self._rt._time_step.reward
self.fig.linepnt[0] = self.maxlen
self.y[0].append(r)
self.fig.linedata[0][:self.maxlen * 2] = np.array([self.x,
self.y[0]]).T.reshape(
(-1,))
pos = mujoco.MjrRect(2 * 300 + 5, viewport.height - 200 - 5, 300, 200)
mujoco.mjr_figure(pos, self.fig, context.ptr)

Expand Down

0 comments on commit 33ebfb8

Please sign in to comment.