From 1808f20df4f0d8d64e2d401e7ba01a032ce02933 Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Wed, 17 Apr 2024 10:16:47 +0200 Subject: [PATCH] add custom weights argument for exp_reduce in MetricCollector --- avalanche/evaluation/collector.py | 14 +++++++++++--- examples/updatable_objects.py | 8 +++----- tests/evaluation/test_collector.py | 26 ++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 8 deletions(-) diff --git a/avalanche/evaluation/collector.py b/avalanche/evaluation/collector.py index 30825b000..475459704 100644 --- a/avalanche/evaluation/collector.py +++ b/avalanche/evaluation/collector.py @@ -79,7 +79,7 @@ def update(self, res, *, stream=None): else: self.metrics_res[k] = [v] - def get(self, name, *, time_reduce=None, exp_reduce=None, stream=None): + def get(self, name, *, time_reduce=None, exp_reduce=None, stream=None, weights=None): """Returns a metric value given its name and aggregation method. :param name: name of the metric. @@ -91,14 +91,20 @@ def get(self, name, *, time_reduce=None, exp_reduce=None, stream=None): - None (default) does not use any aggregation - `sample_mean` is an average weighted by the number of samples in each experience - `experience_mean` is an experience average. + - 'weighted_sum' is a weighted sum of the experiences using the `weights` argument. :param stream: stream that was used to compute the metric. This is needed to build the full metric name if the get was called with a stream name and if `exp_reduce == sample_mean` to get the number of samples from each experience. + :param weights: weights for each experience when `exp_reduce == 'weighted_sum`. :return: aggregated metric value. """ assert time_reduce in {None, "last", "mean"} - assert exp_reduce in {None, "sample_mean", "experience_mean"} + assert exp_reduce in {None, "sample_mean", "experience_mean", "weighted_sum"} + if exp_reduce == "weighted_sum": + assert weights is not None, "You should set the `weights` argument when `exp_reduce == 'weighted_sum'`." + else: + assert weights is None, "Can't use the `weights` argument when `exp_reduce != 'weighted_sum'`" if stream is not None: name = f"{stream.name}/{name}" @@ -122,11 +128,13 @@ def get(self, name, *, time_reduce=None, exp_reduce=None, stream=None): mvals = mvals.sum(axis=1) # weighted avg across exp. elif exp_reduce == "experience_mean": mvals = mvals.mean(axis=1) # avg across exp. + elif exp_reduce == "weighted_sum": + weights = np.array(weights)[None, :] + mvals = (mvals * weights).sum(axis=1) else: raise ValueError("BUG. It should never get here.") if time_reduce is None: - pass # nothing to do here elif time_reduce == "last": mvals = mvals[-1] # last timestep diff --git a/examples/updatable_objects.py b/examples/updatable_objects.py index 481a0bcbe..12532f729 100644 --- a/examples/updatable_objects.py +++ b/examples/updatable_objects.py @@ -25,10 +25,8 @@ def train_experience(agent_state, exp, epochs=10): agent_state.model.train() - data = ( - exp.dataset.train() - ) # avalanche datasets have train/eval modes to switch augmentations - + # avalanche datasets have train/eval modes to switch augmentations + data = exp.dataset.train() agent_state.pre_adapt(exp) # update objects and call pre_hooks for ep in range(epochs): if len(agent_state.replay.buffer) > 0: @@ -64,7 +62,7 @@ def my_eval(model, stream, metrics): res = {uo.__class__.__name__: [] for uo in metrics} for exp in stream: [uo.reset() for uo in metrics] - dl = DataLoader(exp.dataset, batch_size=512, num_workers=8) + dl = DataLoader(exp.dataset.eval(), batch_size=512, num_workers=8) for x, y, _ in dl: x, y = x.cuda(), y.cuda() yp = model(x) diff --git a/tests/evaluation/test_collector.py b/tests/evaluation/test_collector.py index 482ccbcb3..e6146fb99 100644 --- a/tests/evaluation/test_collector.py +++ b/tests/evaluation/test_collector.py @@ -52,6 +52,16 @@ def __len__(self): np.testing.assert_array_almost_equal( v, [(1 + 3) / 2, (5 + 7) / 2, (11 + 13) / 2] ) + v = mc.get( + "FakeMetric", + time_reduce=None, + exp_reduce="weighted_sum", + weights=[1, 2], + stream=fake_stream + ) + np.testing.assert_array_almost_equal( + v, [(1 + 3*2), (5 + 7*2), (11 + 13*2)] + ) # time = "last" v = mc.get( @@ -72,6 +82,14 @@ def __len__(self): stream=fake_stream, ) self.assertAlmostEqual(v, (11 + 13) / 2) + v = mc.get( + "FakeMetric", + time_reduce="last", + exp_reduce="weighted_sum", + stream=fake_stream, + weights=[1, 2] + ) + self.assertAlmostEqual(v, 11 + 13*2) # time_reduce = "mean" v = mc.get( @@ -92,6 +110,14 @@ def __len__(self): stream=fake_stream, ) self.assertAlmostEqual(v, ((1 + 5 + 11) / 3 + (3 + 7 + 13) / 3) / 2) + v = mc.get( + "FakeMetric", + time_reduce="mean", + exp_reduce="weighted_sum", + stream=fake_stream, + weights=[1, 2] + ) + self.assertAlmostEqual(v, ((1 + 3*2) + (5 + 7*2) + (11 + 13*2)) / 3) if __name__ == "__main__":