-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathil2m_strategy.py
101 lines (90 loc) · 3.63 KB
/
il2m_strategy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from typing import Optional, Union, List, Callable
import torch
from torch.nn import Module
from torch.optim import Optimizer
from avalanche.training.templates import SupervisedTemplate
from avalanche.training.templates.strategy_mixin_protocol import CriterionType
from avalanche.training.plugins import SupervisedPlugin, EvaluationPlugin
from avalanche.training.plugins.evaluation import default_evaluator
from il2m_plugin import IL2MPlugin
class IL2M(SupervisedTemplate):
    """Class Incremental Learning With Dual Memory (IL2M) strategy.

    See the IL2M plugin for details.
    This strategy does not use task identities.
    """

    def __init__(
        self,
        *,
        model: Module,
        optimizer: Optimizer,
        criterion: CriterionType,
        mem_size: int = 2000,
        mem_mb_size: Optional[int] = None,
        train_mb_size: int = 1,
        train_epochs: int = 1,
        eval_mb_size: Optional[int] = None,
        storage_policy: Optional["ExemplarsBuffer"] = None,
        device: Union[str, torch.device] = "cpu",
        plugins: Optional[List[SupervisedPlugin]] = None,
        evaluator: Union[
            EvaluationPlugin, Callable[[], EvaluationPlugin]
        ] = default_evaluator,
        eval_every=-1,
        peval_mode="epoch",
        **base_kwargs
    ):
        """Init.

        :param model: The model.
        :param optimizer: The optimizer to use.
        :param criterion: The loss criterion to use.
        :param mem_size: Replay buffer size. Defaults to 2000.
        :param mem_mb_size: The size of the memory batch. Defaults to None.
        :param train_mb_size: The train minibatch size. Defaults to 1.
        :param train_epochs: The number of training epochs. Defaults to 1.
        :param eval_mb_size: The eval minibatch size. Defaults to None.
        :param storage_policy: The policy that controls how to add new exemplars
            in memory. Defaults to None.
        :param device: The device to use. Defaults to "cpu".
        :param plugins: Plugins to be added. Defaults to None. The passed list
            is not modified; a new list including the IL2M plugin is built.
        :param evaluator: (optional) Instance of EvaluationPlugin for logging
            and metric computations.
        :param eval_every: The frequency of the calls to `eval` inside the
            training loop. -1 disables the evaluation. 0 means `eval` is called
            only at the end of the learning experience. Values >0 mean that
            `eval` is called every `eval_every` epochs and at the end of the
            learning experience. Defaults to -1.
        :param peval_mode: one of {'epoch', 'iteration'}. Decides whether
            the periodic evaluation during training should execute every
            `eval_every` epochs or iterations. Defaults to 'epoch'.
        :param **base_kwargs: any additional
            :class:`~avalanche.training.BaseTemplate` constructor arguments.
        """
        # Instantiate the plugin that implements the IL2M dual-memory logic.
        il2m = IL2MPlugin(
            mem_size=mem_size,
            batch_size=train_mb_size,
            batch_size_mem=mem_mb_size,
            storage_policy=storage_policy,
        )

        # Add the plugin to the strategy. Build a NEW list instead of
        # appending in place: mutating the caller's `plugins` argument would
        # accumulate an extra IL2MPlugin on every construction that reuses
        # the same list (shared-mutable-argument bug).
        if plugins is None:
            plugins = [il2m]
        else:
            plugins = list(plugins) + [il2m]

        super().__init__(
            model=model,
            optimizer=optimizer,
            criterion=criterion,
            train_mb_size=train_mb_size,
            train_epochs=train_epochs,
            eval_mb_size=eval_mb_size,
            device=device,
            plugins=plugins,
            evaluator=evaluator,
            eval_every=eval_every,
            peval_mode=peval_mode,
            **base_kwargs
        )
# Public API of this module.
__all__ = ["IL2M"]