add DynamicOptimizer wrapper
AntonioCarta committed Apr 17, 2024
1 parent 847af49 commit 1128d43
Showing 4 changed files with 97 additions and 14 deletions.
2 changes: 1 addition & 1 deletion avalanche/models/dynamic_modules.py
@@ -73,7 +73,7 @@ def __init__(self, auto_adapt=True):
super().__init__()
self._auto_adapt = auto_adapt

def pre_adapt(self, experience):
def pre_adapt(self, agent, experience):
"""
Calls self.adaptation recursively across
the hierarchy of PyTorch module children
53 changes: 52 additions & 1 deletion avalanche/models/dynamic_optimizers.py
@@ -18,7 +18,9 @@

import numpy as np

from avalanche._annotations import deprecated
from avalanche._annotations import deprecated, experimental
from avalanche.benchmarks import CLExperience
from avalanche.core import Adaptable, Agent

colors = {
"END": "\033[0m",
@@ -31,6 +33,55 @@
colors[None] = colors["END"]


@experimental("New dynamic optimizers. The API may slightly change in the next versions.")
class DynamicOptimizer(Adaptable):
"""Avalanche dynamic optimizer.
In continual learning, many model architecture may change over time (e.g.
adding new units to the classifier). This is handled by the `DynamicModule`.
Changing the model's architecture requires updating the optimizer to add
the new parameters in the optimizer's state.
This class provides a simple wrapper to handle the optimizer
update via the `Adaptable` Protocol.
This class provides direct methods only for `zero_grad` and `step` to support
basic training functionality. Other methods of the base optimizers must be
called by using the base optimizer directly (e.g. `self.optim.add_param_group`).
NOTE: the model must be adapted *before* calling this method.
To ensure this, ensure that the optimizer is added to the agent state
after the model:
.. code-block::
agent.model = model
agent.optimizer = optimizer
# ... more init code
# pre_adapt will call the pre_adapt methods in order,
# first model.pre_adapt, then optimizer.pre_adapt
agent.pre_adapt(experience)
"""
def __init__(self, optim):
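# keep a reference to the wrapped torch optimizer; methods other than
# zero_grad/step must be called on self.optim directly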
self.optim = optim

def zero_grad(self):
self.optim.zero_grad()

def step(self):
self.optim.step()

def pre_adapt(self, agent: Agent, exp: CLExperience):
"""Adapt the optimizer before training on the current experience."""
update_optimizer(
self.optim,
new_params=dict(agent.model.named_parameters()),
optimized_params=dict(agent.model.named_parameters())
)


def _map_optimized_params(optimizer, parameters, old_params=None):
"""
Establishes a mapping between a list of named parameters and the parameters
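A minimal usage sketch for the new wrapper (not part of the diff; it mirrors the unit test added below and uses the `get_fast_benchmark` helper from the test utilities):

from torch.optim import SGD

from avalanche.core import Agent
from avalanche.models import SimpleMLP, as_multitask
from avalanche.models.dynamic_optimizers import DynamicOptimizer
from tests.unit_tests_utils import get_fast_benchmark

bm = get_fast_benchmark(use_task_labels=True)
agent = Agent()
# register the model first, then the optimizer, so that agent.pre_adapt
# adapts the model before the optimizer state is refreshed
agent.model = as_multitask(SimpleMLP(input_size=6), "classifier")
agent.opt = DynamicOptimizer(SGD(agent.model.parameters(), lr=0.001))

for exp in bm.train_stream:
    agent.pre_adapt(exp)  # model adaptation first, then the optimizer update
    # ... training steps using agent.opt.zero_grad() and agent.opt.step() ...
    # anything else (e.g. add_param_group) goes through agent.opt.optim directly
    agent.post_adapt(exp)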
22 changes: 10 additions & 12 deletions examples/updatable_objects.py
@@ -18,7 +18,7 @@
from avalanche.evaluation.plot_utils import plot_metric_matrix
from avalanche.models import SimpleMLP, IncrementalClassifier
from avalanche.models.dynamic_modules import avalanche_model_adaptation
from avalanche.models.dynamic_optimizers import update_optimizer
from avalanche.models.dynamic_optimizers import update_optimizer, DynamicOptimizer
from avalanche.training import ReservoirSamplingBuffer, LearningWithoutForgetting
from avalanche.training.losses import MaskedCrossEntropy

@@ -55,10 +55,11 @@ def train_experience(agent_state, exp, epochs=10):

@torch.no_grad()
def my_eval(model, stream, metrics):
# eval also becomes simpler. Notice how in Avalanche it's harder to check whether
# we are evaluating a single exp. or the whole stream.
# Now we evaluate each stream with a separate function call
"""Evaluate `model` on `stream` computing `metrics`.
Returns a dictionary {metric_name: list-of-results}.
"""
model.eval()
res = {uo.__class__.__name__: [] for uo in metrics}
for exp in stream:
[uo.reset() for uo in metrics]
@@ -96,16 +97,13 @@ def my_eval(model, stream, metrics):
agent.add_pre_hooks(lambda a, e: avalanche_model_adaptation(a.model, e))

# optimizer and scheduler
agent.opt = SGD(agent.model.parameters(), lr=0.001)
agent.scheduler = ExponentialLR(agent.opt, gamma=0.999)
# we use a hook to update the optimizer before each experience.
# we have to update the optimizer before each experience.
# This is needed because the model's parameters may change if you are using
# a dynamic model.
agent.add_pre_hooks(
lambda a, e: update_optimizer(
a.opt, new_params={}, optimized_params=dict(a.model.named_parameters())
)
)
opt = SGD(agent.model.parameters(), lr=0.001)
agent.opt = DynamicOptimizer(opt)
agent.scheduler = ExponentialLR(opt, gamma=0.999)
# we use a hook to call the scheduler.
# we update the lr scheduler after each experience (not every epoch!)
agent.add_post_hooks(lambda a, e: a.scheduler.step())

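To spell out the change above: the wrapper replaces the manual pre-hook, since `DynamicOptimizer.pre_adapt` now performs the optimizer update whenever `agent.pre_adapt(exp)` runs. A condensed before/after sketch, reusing the names from the example:

# before: refresh the optimizer state manually with a pre-hook
agent.opt = SGD(agent.model.parameters(), lr=0.001)
agent.scheduler = ExponentialLR(agent.opt, gamma=0.999)
agent.add_pre_hooks(
    lambda a, e: update_optimizer(
        a.opt, new_params={}, optimized_params=dict(a.model.named_parameters())
    )
)

# after: the wrapper is an Adaptable attribute of the agent, so
# agent.pre_adapt(exp) updates it right after the model is adapted and
# no extra hook is needed; the scheduler still wraps the base optimizer
opt = SGD(agent.model.parameters(), lr=0.001)
agent.opt = DynamicOptimizer(opt)
agent.scheduler = ExponentialLR(opt, gamma=0.999)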
34 changes: 34 additions & 0 deletions tests/models/test_dynamic_optimizers.py
@@ -0,0 +1,34 @@
import unittest

from torch.optim import SGD
from torch.utils.data import DataLoader

from avalanche.core import Agent
from avalanche.models import SimpleMLP, as_multitask
from avalanche.models.dynamic_optimizers import DynamicOptimizer
from avalanche.training import MaskedCrossEntropy
from tests.unit_tests_utils import get_fast_benchmark


class TestDynamicOptimizers(unittest.TestCase):
def test_dynamic_optimizer(self):
bm = get_fast_benchmark(use_task_labels=True)
agent = Agent()
agent.loss = MaskedCrossEntropy()
agent.model = as_multitask(SimpleMLP(input_size=6), "classifier")
opt = SGD(agent.model.parameters(), lr=0.001)
agent.opt = DynamicOptimizer(opt)

for exp in bm.train_stream:
agent.model.train()
data = exp.dataset.train()
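# adapt the model first, then refresh the wrapped optimizer's state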
agent.pre_adapt(exp)
for ep in range(1):
dl = DataLoader(data, batch_size=32, shuffle=True)
for x, y, t in dl:
agent.opt.zero_grad()
yp = agent.model(x, t)
loss = agent.loss(yp, y)
loss.backward()
agent.opt.step()
agent.post_adapt(exp)
