Skip to content

Commit

Permalink
add custom weights argument for exp_reduce in MetricCollector
Browse files Browse the repository at this point in the history
  • Loading branch information
AntonioCarta committed Apr 17, 2024
1 parent 639454e commit 1808f20
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 8 deletions.
14 changes: 11 additions & 3 deletions avalanche/evaluation/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def update(self, res, *, stream=None):
else:
self.metrics_res[k] = [v]

def get(self, name, *, time_reduce=None, exp_reduce=None, stream=None):
def get(self, name, *, time_reduce=None, exp_reduce=None, stream=None, weights=None):
"""Returns a metric value given its name and aggregation method.
:param name: name of the metric.
Expand All @@ -91,14 +91,20 @@ def get(self, name, *, time_reduce=None, exp_reduce=None, stream=None):
- None (default) does not use any aggregation
- `sample_mean` is an average weighted by the number of samples in each experience
- `experience_mean` is an experience average.
- 'weighted_sum' is a weighted sum of the experiences using the `weights` argument.
:param stream: stream that was used to compute the metric. This is
needed to build the full metric name if the get was called with a
stream name and if `exp_reduce == sample_mean` to get the number
of samples from each experience.
:param weights: weights for each experience when `exp_reduce == 'weighted_sum`.
:return: aggregated metric value.
"""
assert time_reduce in {None, "last", "mean"}
assert exp_reduce in {None, "sample_mean", "experience_mean"}
assert exp_reduce in {None, "sample_mean", "experience_mean", "weighted_sum"}
if exp_reduce == "weighted_sum":
assert weights is not None, "You should set the `weights` argument when `exp_reduce == 'weighted_sum'`."
else:
assert weights is None, "Can't use the `weights` argument when `exp_reduce != 'weighted_sum'`"

if stream is not None:
name = f"{stream.name}/{name}"
Expand All @@ -122,11 +128,13 @@ def get(self, name, *, time_reduce=None, exp_reduce=None, stream=None):
mvals = mvals.sum(axis=1) # weighted avg across exp.
elif exp_reduce == "experience_mean":
mvals = mvals.mean(axis=1) # avg across exp.
elif exp_reduce == "weighted_sum":
weights = np.array(weights)[None, :]
mvals = (mvals * weights).sum(axis=1)
else:
raise ValueError("BUG. It should never get here.")

if time_reduce is None:

pass # nothing to do here
elif time_reduce == "last":
mvals = mvals[-1] # last timestep
Expand Down
8 changes: 3 additions & 5 deletions examples/updatable_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,8 @@

def train_experience(agent_state, exp, epochs=10):
agent_state.model.train()
data = (
exp.dataset.train()
) # avalanche datasets have train/eval modes to switch augmentations

# avalanche datasets have train/eval modes to switch augmentations
data = exp.dataset.train()
agent_state.pre_adapt(exp) # update objects and call pre_hooks
for ep in range(epochs):
if len(agent_state.replay.buffer) > 0:
Expand Down Expand Up @@ -64,7 +62,7 @@ def my_eval(model, stream, metrics):
res = {uo.__class__.__name__: [] for uo in metrics}
for exp in stream:
[uo.reset() for uo in metrics]
dl = DataLoader(exp.dataset, batch_size=512, num_workers=8)
dl = DataLoader(exp.dataset.eval(), batch_size=512, num_workers=8)
for x, y, _ in dl:
x, y = x.cuda(), y.cuda()
yp = model(x)
Expand Down
26 changes: 26 additions & 0 deletions tests/evaluation/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,16 @@ def __len__(self):
np.testing.assert_array_almost_equal(
v, [(1 + 3) / 2, (5 + 7) / 2, (11 + 13) / 2]
)
v = mc.get(
"FakeMetric",
time_reduce=None,
exp_reduce="weighted_sum",
weights=[1, 2],
stream=fake_stream
)
np.testing.assert_array_almost_equal(
v, [(1 + 3*2), (5 + 7*2), (11 + 13*2)]
)

# time = "last"
v = mc.get(
Expand All @@ -72,6 +82,14 @@ def __len__(self):
stream=fake_stream,
)
self.assertAlmostEqual(v, (11 + 13) / 2)
v = mc.get(
"FakeMetric",
time_reduce="last",
exp_reduce="weighted_sum",
stream=fake_stream,
weights=[1, 2]
)
self.assertAlmostEqual(v, 11 + 13*2)

# time_reduce = "mean"
v = mc.get(
Expand All @@ -92,6 +110,14 @@ def __len__(self):
stream=fake_stream,
)
self.assertAlmostEqual(v, ((1 + 5 + 11) / 3 + (3 + 7 + 13) / 3) / 2)
v = mc.get(
"FakeMetric",
time_reduce="mean",
exp_reduce="weighted_sum",
stream=fake_stream,
weights=[1, 2]
)
self.assertAlmostEqual(v, ((1 + 3*2) + (5 + 7*2) + (11 + 13*2)) / 3)


if __name__ == "__main__":
Expand Down

0 comments on commit 1808f20

Please sign in to comment.