il2m_plugin.py
from typing import Optional, TYPE_CHECKING

from packaging.version import parse
import torch
import numpy as np

from avalanche.training.plugins.strategy_plugin import SupervisedPlugin
from avalanche.training.storage_policy import (
    ExemplarsBuffer,
    ExperienceBalancedBuffer,
)
from avalanche.benchmarks.utils.data_loader import ReplayDataLoader

if TYPE_CHECKING:
    from avalanche.training.templates import SupervisedTemplate

class IL2MPlugin(SupervisedPlugin):
    """
    Class Incremental Learning With Dual Memory (IL2M) plugin.

    Technique introduced in:
    Belouadah, E. and Popescu, A. "IL2M: Class Incremental Learning With Dual
    Memory." Proceedings of the IEEE/CVF International Conference on Computer
    Vision (ICCV). 2019.

    Implementation based on FACIL, as in:
    https://github.com/mmasana/FACIL/blob/master/src/approach/il2m.py
    """

    def __init__(
        self,
        mem_size: int = 2000,
        batch_size: Optional[int] = None,
        batch_size_mem: Optional[int] = None,
        storage_policy: Optional[ExemplarsBuffer] = None,
    ):
        """
        :param mem_size: replay buffer size.
        :param batch_size: the size of the data batch. If set to `None`, it
            will be set equal to the strategy's batch size.
        :param batch_size_mem: the size of the memory batch. If its value is
            set to `None` (the default value), it will be automatically set
            equal to the data batch size.
        :param storage_policy: the policy that controls how to add new
            exemplars in memory.
        """
        super().__init__()
        self.mem_size = mem_size
        self.batch_size = batch_size
        self.batch_size_mem = batch_size_mem

        if storage_policy is not None:  # use the given storage policy
            self.storage_policy = storage_policy
            assert storage_policy.max_size == self.mem_size
        else:  # default
            self.storage_policy = ExperienceBalancedBuffer(
                max_size=self.mem_size, adaptive_size=True
            )

        # to store statistics for the classes as learned in the current
        # incremental state
        self.current_classes_means = []
        # to store statistics for past classes as learned in the incremental
        # state in which they were first seen
        self.init_classes_means = []
        # to store statistics for model confidence in different states
        # (i.e. avg top-1 pred scores)
        self.models_confidence = []
        # to store the mapping between classes and the incremental state in
        # which they were first seen
        self.classes2exp = []
        # total number of classes that will be seen
        self.n_classes = 0

    def before_training_exp(
        self,
        strategy: "SupervisedTemplate",
        num_workers: int = 0,
        shuffle: bool = True,
        drop_last: bool = False,
        **kwargs
    ):
        if len(self.init_classes_means) == 0:
            self.n_classes = len(strategy.experience.classes_seen_so_far) + len(
                strategy.experience.future_classes
            )
            self.init_classes_means = [0 for _ in range(self.n_classes)]
            self.classes2exp = [-1 for _ in range(self.n_classes)]

        if len(self.storage_policy.buffer) == 0:
            # first experience. We don't use the buffer, no need to change
            # the dataloader.
            return

        batch_size = self.batch_size
        if batch_size is None:
            batch_size = strategy.train_mb_size

        batch_size_mem = self.batch_size_mem
        if batch_size_mem is None:
            batch_size_mem = strategy.train_mb_size

        assert strategy.adapted_dataset is not None

        other_dataloader_args = dict()

        if "ffcv_args" in kwargs:
            other_dataloader_args["ffcv_args"] = kwargs["ffcv_args"]

        if "persistent_workers" in kwargs:
            if parse(torch.__version__) >= parse("1.7.0"):
                other_dataloader_args["persistent_workers"] = kwargs[
                    "persistent_workers"
                ]
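
        # Replace the strategy dataloader with one that mixes minibatches from
        # the current experience with exemplars drawn from the replay buffer,
        # oversampling small tasks so that old classes stay represented.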
        strategy.dataloader = ReplayDataLoader(
            strategy.adapted_dataset,
            self.storage_policy.buffer,
            oversample_small_tasks=True,
            batch_size=batch_size,
            batch_size_mem=batch_size_mem,
            num_workers=num_workers,
            shuffle=shuffle,
            drop_last=drop_last,
            **other_dataloader_args
        )

    def after_training_exp(self, strategy: "SupervisedTemplate", **kwargs):
        experience = strategy.experience
        self.current_classes_means = [0 for _ in range(self.n_classes)]
        classes_counts = [0 for _ in range(self.n_classes)]
        self.models_confidence.append(0)
        models_counts = 0

        # compute the mean prediction scores that will be used to rectify
        # scores in subsequent incremental states
        with torch.no_grad():
            strategy.model.eval()
            for inputs, targets, _ in strategy.dataloader:
                inputs, targets = inputs.to(strategy.device), targets.to(
                    strategy.device
                )
                outputs = strategy.model(inputs)
                scores = outputs.data.cpu().numpy()
                for i in range(len(targets)):
                    target = targets[i].item()
                    classes_counts[target] += 1
                    if target in experience.previous_classes:
                        # compute the mean prediction scores for past classes
                        # of the current state
                        self.current_classes_means[target] += scores[i, target]
                    else:
                        # compute the mean prediction scores for the new
                        # classes of the current state
                        self.init_classes_means[target] += scores[i, target]
                        # compute the mean top scores for the new classes of
                        # the current state
                        self.models_confidence[-1] += np.max(scores[i])
                        models_counts += 1

        # normalize by the corresponding number of samples
        for cls in experience.previous_classes:
            self.current_classes_means[cls] /= classes_counts[cls]
        for cls in experience.classes_in_this_experience:
            self.init_classes_means[cls] /= classes_counts[cls]
        self.models_confidence[-1] /= models_counts

        # store the mapping between classes and the incremental state in which
        # they are first seen
        for cls in experience.classes_in_this_experience:
            self.classes2exp[cls] = experience.current_experience

        # update the buffer of exemplars
        self.storage_policy.post_adapt(strategy, strategy.experience)

    def after_eval_forward(self, strategy: "SupervisedTemplate", **kwargs):
        old_classes = strategy.experience.previous_classes
        new_classes = strategy.experience.classes_in_this_experience
        if not old_classes:
            return

        outputs = strategy.mb_output
        targets = strategy.mbatch[1]

        # rectify predicted scores (Eq. 1 in the paper)
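        # Each old-class score is multiplied by
        #   (init_classes_means[cls] / current_classes_means[cls])
        #       * (models_confidence[current] / models_confidence[state of cls]),
        # which compensates for the decay of the mean score of cls since the
        # state in which it was first learned, weighted by the relative model
        # confidence of the two states.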
        for i in range(len(targets)):
            # if the top-1 class predicted by the network is a new one,
            # rectify the scores of the old classes
            if outputs[i].argmax().item() in new_classes:
                for cls in old_classes:
                    o_exp = self.classes2exp[cls]
                    if self.current_classes_means[cls] == 0:
                        # evaluation was run before training on this class
                        continue
                    outputs[i, cls] *= (
                        self.init_classes_means[cls]
                        / self.current_classes_means[cls]
                    ) * (
                        self.models_confidence[-1]
                        / self.models_confidence[o_exp]
                    )
            # otherwise, rectification is not done because an old class is
            # directly predicted


__all__ = [
    "IL2MPlugin",
]
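

# The block below is a minimal usage sketch, not part of the plugin itself.
# It assumes a standard Avalanche setup (SimpleMLP, SplitMNIST and the Naive
# strategy, all imported from the Avalanche library); the hyperparameters are
# illustrative only.
if __name__ == "__main__":
    from torch.nn import CrossEntropyLoss
    from torch.optim import SGD

    from avalanche.benchmarks.classic import SplitMNIST
    from avalanche.models import SimpleMLP
    from avalanche.training.supervised import Naive

    benchmark = SplitMNIST(n_experiences=5)
    model = SimpleMLP(num_classes=benchmark.n_classes)

    # Any supervised strategy works: the plugin hooks into training to mix in
    # replayed exemplars and into evaluation to rectify prediction scores.
    strategy = Naive(
        model=model,
        optimizer=SGD(model.parameters(), lr=0.01, momentum=0.9),
        criterion=CrossEntropyLoss(),
        train_mb_size=64,
        train_epochs=1,
        eval_mb_size=64,
        plugins=[IL2MPlugin(mem_size=2000)],
    )

    for experience in benchmark.train_stream:
        strategy.train(experience)
        strategy.eval(benchmark.test_stream)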