From 1128d4321ad53c38f569d9da795a32088da4e359 Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Wed, 17 Apr 2024 11:44:50 +0200 Subject: [PATCH] add DynamicOptimizer wrapper --- avalanche/models/dynamic_modules.py | 2 +- avalanche/models/dynamic_optimizers.py | 53 ++++++++++++++++++++++++- examples/updatable_objects.py | 22 +++++----- tests/models/test_dynamic_optimizers.py | 34 ++++++++++++++++ 4 files changed, 97 insertions(+), 14 deletions(-) create mode 100644 tests/models/test_dynamic_optimizers.py diff --git a/avalanche/models/dynamic_modules.py b/avalanche/models/dynamic_modules.py index b2466dff9..9a43cda54 100644 --- a/avalanche/models/dynamic_modules.py +++ b/avalanche/models/dynamic_modules.py @@ -73,7 +73,7 @@ def __init__(self, auto_adapt=True): super().__init__() self._auto_adapt = auto_adapt - def pre_adapt(self, experience): + def pre_adapt(self, agent, experience): """ Calls self.adaptation recursively accross the hierarchy of pytorch module childrens diff --git a/avalanche/models/dynamic_optimizers.py b/avalanche/models/dynamic_optimizers.py index 33d6eb72d..a1a97523c 100644 --- a/avalanche/models/dynamic_optimizers.py +++ b/avalanche/models/dynamic_optimizers.py @@ -18,7 +18,9 @@ import numpy as np -from avalanche._annotations import deprecated +from avalanche._annotations import deprecated, experimental +from avalanche.benchmarks import CLExperience +from avalanche.core import Adaptable, Agent colors = { "END": "\033[0m", @@ -31,6 +33,55 @@ colors[None] = colors["END"] +@experimental("New dynamic optimizers. The API may slightly change in the next versions.") +class DynamicOptimizer(Adaptable): + """Avalanche dynamic optimizer. + + In continual learning, the model architecture may change over time (e.g. + adding new units to the classifier). This is handled by the `DynamicModule`. + Changing the model's architecture requires updating the optimizer to add + the new parameters to the optimizer's state. 
+ + This class provides a simple wrapper to handle the optimizer + update via the `Adaptable` Protocol. + + This class provides direct methods only for `zero_grad` and `step` to support + basic training functionality. Other methods of the base optimizers must be + called by using the base optimizer directly (e.g. `self.optim.add_param_group`). + + NOTE: the model must be adapted *before* the optimizer's `pre_adapt` is called. + To guarantee this, make sure that the optimizer is added to the agent state + after the model: + + .. code-block:: + + agent.model = model + agent.optimizer = optimizer + + # ... more init code + + # pre_adapt will call the pre_adapt methods in order, + # first model.pre_adapt, then optimizer.pre_adapt + agent.pre_adapt(experience) + """ + def __init__(self, optim): + self.optim = optim + + def zero_grad(self): + self.optim.zero_grad() + + def step(self): + self.optim.step() + + def pre_adapt(self, agent: Agent, exp: CLExperience): + """Adapt the optimizer before training on the current experience.""" + update_optimizer( + self.optim, + new_params=dict(agent.model.named_parameters()), + optimized_params=dict(agent.model.named_parameters()) + ) + + def _map_optimized_params(optimizer, parameters, old_params=None): """ Establishes a mapping between a list of named parameters and the parameters diff --git a/examples/updatable_objects.py b/examples/updatable_objects.py index 12532f729..5313561d6 100644 --- a/examples/updatable_objects.py +++ b/examples/updatable_objects.py @@ -18,7 +18,7 @@ from avalanche.evaluation.plot_utils import plot_metric_matrix from avalanche.models import SimpleMLP, IncrementalClassifier from avalanche.models.dynamic_modules import avalanche_model_adaptation -from avalanche.models.dynamic_optimizers import update_optimizer +from avalanche.models.dynamic_optimizers import update_optimizer, DynamicOptimizer from avalanche.training import ReservoirSamplingBuffer, LearningWithoutForgetting from avalanche.training.losses import MaskedCrossEntropy 
@@ -55,10 +55,11 @@ def train_experience(agent_state, exp, epochs=10): @torch.no_grad() def my_eval(model, stream, metrics): - # eval also becomes simpler. Notice how in Avalanche it's harder to check whether - # we are evaluating a single exp. or the whole stream. - # Now we evaluate each stream with a separate function call + """Evaluate `model` on `stream` computing `metrics`. + Returns a dictionary {metric_name: list-of-results}. + """ + model.eval() res = {uo.__class__.__name__: [] for uo in metrics} for exp in stream: [uo.reset() for uo in metrics] @@ -96,16 +97,13 @@ def my_eval(model, stream, metrics): agent.add_pre_hooks(lambda a, e: avalanche_model_adaptation(a.model, e)) # optimizer and scheduler - agent.opt = SGD(agent.model.parameters(), lr=0.001) - agent.scheduler = ExponentialLR(agent.opt, gamma=0.999) - # we use a hook to update the optimizer before each experience. + # we have to update the optimizer before each experience. # This is needed because the model's parameters may change if you are using # a dynamic model. - agent.add_pre_hooks( - lambda a, e: update_optimizer( - a.opt, new_params={}, optimized_params=dict(a.model.named_parameters()) - ) - ) + opt = SGD(agent.model.parameters(), lr=0.001) + agent.opt = DynamicOptimizer(opt) + agent.scheduler = ExponentialLR(opt, gamma=0.999) + # we use a hook to call the scheduler. # we update the lr scheduler after each experience (not every epoch!) 
agent.add_post_hooks(lambda a, e: a.scheduler.step()) diff --git a/tests/models/test_dynamic_optimizers.py b/tests/models/test_dynamic_optimizers.py new file mode 100644 index 000000000..0012749ed --- /dev/null +++ b/tests/models/test_dynamic_optimizers.py @@ -0,0 +1,34 @@ +import unittest + +from torch.optim import SGD +from torch.utils.data import DataLoader + +from avalanche.core import Agent +from avalanche.models import SimpleMLP, as_multitask +from avalanche.models.dynamic_optimizers import DynamicOptimizer +from avalanche.training import MaskedCrossEntropy +from tests.unit_tests_utils import get_fast_benchmark + + +class TestDynamicOptimizers(unittest.TestCase): + def test_dynamic_optimizer(self): + bm = get_fast_benchmark(use_task_labels=True) + agent = Agent() + agent.loss = MaskedCrossEntropy() + agent.model = as_multitask(SimpleMLP(input_size=6), "classifier") + opt = SGD(agent.model.parameters(), lr=0.001) + agent.opt = DynamicOptimizer(opt) + + for exp in bm.train_stream: + agent.model.train() + data = exp.dataset.train() + agent.pre_adapt(exp) + for ep in range(1): + dl = DataLoader(data, batch_size=32, shuffle=True) + for x, y, t in dl: + agent.opt.zero_grad() + yp = agent.model(x, t) + l = agent.loss(yp, y) + l.backward() + agent.opt.step() + agent.post_adapt(exp)