Skip to content

Commit

Permalink
Update running stats.
Browse files Browse the repository at this point in the history
- now updating the stats with batches allowed.
  • Loading branch information
robfiras committed Jan 22, 2024
1 parent a9f3f1a commit 27845c2
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 20 deletions.
12 changes: 12 additions & 0 deletions mushroom_rl/core/array_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ def abs(array):
def clip(array, min, max):
raise NotImplementedError

@staticmethod
def atleast_2d(array):
raise NotImplementedError

@staticmethod
def copy(array):
raise NotImplementedError
Expand Down Expand Up @@ -223,6 +227,10 @@ def abs(array):
def clip(array, min, max):
return np.clip(array, min, max)

@staticmethod
def atleast_2d(array):
return np.atleast_2d(array)

@staticmethod
def copy(array):
return array.copy()
Expand Down Expand Up @@ -328,6 +336,10 @@ def abs(array):
def clip(array, min, max):
return torch.clip(array, min, max)

@staticmethod
def atleast_2d(array):
return torch.atleast_2d(array)

@staticmethod
def copy(array):
return array.clone()
Expand Down
4 changes: 0 additions & 4 deletions mushroom_rl/rl_utils/preprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,6 @@ def __init__(self, mdp_info, backend, clip_obs=10., alpha=1e-32):
)

def __call__(self, obs):
assert obs.shape == self._obs_shape, \
"Values given to running_norm have incorrect shape " \
"(obs shape: {}, expected shape: {})" \
.format(obs.shape, self._obs_shape)

norm_obs = self._array_backend.clip(
(obs - self._obs_runstand.mean) / self._obs_runstand.std,
Expand Down
39 changes: 23 additions & 16 deletions mushroom_rl/rl_utils/running_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,23 @@ def reset(self):
"""
self._n = 1
self._m = self._array_backend.zeros(*self._shape)
self._s = self._array_backend.ones(*self._shape)
self._m = self._array_backend.zeros(1, *self._shape)
self._s = self._array_backend.ones(1, *self._shape)

def update_stats(self, value):
"""
Update the statistics with the current data value.
Args:
value (np.ndarray): current data value to use for the update.
value (Array): current data value to use for the update.
"""
self._n += 1
alpha = max(1. / self._n, self._alpha)
new_m = (1 - alpha) * self._m + alpha * value
new_s = self._s + (value - self._m) * (value - new_m)
value = self._array_backend.atleast_2d(value)
batch_size = len(value)
self._n += batch_size
alpha = max(batch_size / self._n, self._alpha)
new_m = (1 - alpha) * self._m + alpha * value.mean(0)
new_s = self._s + (value.mean(0) - self._m) * (value.mean(0) - new_m)
self._m, self._s = new_m, new_s

@property
Expand All @@ -70,7 +72,7 @@ def mean(self):
The estimated mean value.
"""
return self._m
return self._array_backend.squeeze(self._m)

@property
def std(self):
Expand All @@ -79,7 +81,7 @@ def std(self):
The estimated standard deviation value.
"""
return self._array_backend.sqrt(self._s / self._n)
return self._array_backend.squeeze(self._array_backend.sqrt(self._s / self._n))


class RunningExpWeightedAverage(Serializable):
Expand Down Expand Up @@ -116,24 +118,26 @@ def reset(self, init_value=None):
Reset the mean and standard deviation.
Args:
init_value (np.ndarray): initial value of the filter.
init_value (Array): initial value of the filter.
"""
if init_value is None:
self._avg_value = self._array_backend.zeros(*self._shape)
self._avg_value = self._array_backend.zeros(1, *self._shape)
else:
self._avg_value = self._array_backend.convert(init_value)
self._avg_value = self._array_backend.atleast_2d(self._array_backend.convert(init_value))

def update_stats(self, value):
"""
Update the statistics with the current data value.
Args:
value (np.ndarray): current data value to use for the update.
value (Array): current data value to use for the update.
"""
self._avg_value = (
1. - self._alpha) * self._avg_value + self._alpha * value
value = self._array_backend.atleast_2d(value)
batch_size = len(value)
for i in range(batch_size):
self._avg_value = (1. - self._alpha) * self._avg_value + self._alpha * value[i]

@property
def mean(self):
Expand Down Expand Up @@ -196,7 +200,10 @@ def update_stats(self, value):
value (np.ndarray): current data value to use for the update.
"""
self._avg_buffer.append(value)
value = self._array_backend.atleast_2d(value)
batch_size = len(value)
for i in range(batch_size):
self._avg_buffer.append(value[i])

@property
def mean(self):
Expand Down

0 comments on commit 27845c2

Please sign in to comment.