-
Notifications
You must be signed in to change notification settings - Fork 100
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Alpha release Co-authored-by: Mikel Menta Garde <[email protected]> Co-authored-by: Bartlomiej Twardowski <[email protected]> Co-authored-by: Xialei Liu <[email protected]>
- Loading branch information
1 parent
e0563c7
commit 74a826f
Showing
42 changed files
with
2,178 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# Framework for Analysis of Class-Incremental Learning | ||
Run the code with: | ||
``` | ||
python3 -u src/main_incremental.py | ||
``` | ||
followed by general options: | ||
|
||
* `--gpu`: index of GPU to run the experiment on (default=0) | ||
* `--results-path`: path where results are stored (default='../results') | ||
* `--exp-name`: experiment name (default=None) | ||
* `--seed`: random seed (default=0) | ||
* `--save-models`: save trained models (default=False) | ||
* `--last-layer-analysis`: plot last layer analysis (default=False) | ||
* `--no-cudnn-deterministic`: disable CUDNN deterministic (default=False) | ||
|
||
and specific options for each of the code parts (corresponding to folders): | ||
|
||
* `--approach`: learning approach used (default='finetuning') [[more](approaches/README.md)] | ||
* `--datasets`: dataset or datasets used (default=['cifar100']) [[more](datasets/README.md)] | ||
* `--network`: network architecture used (default='resnet32') [[more](networks/README.md)] | ||
* `--log`: loggers used (default='disk') [[more](loggers/README.md)] | ||
|
||
go to each of their respective readme to see all available options for each of them. | ||
|
||
## Approaches | ||
Initially, the approaches included in the framework correspond to the ones presented in | ||
_**Class-incremental learning: survey and performance evaluation**_ (preprint , 2020). The regularization-based | ||
approaches are EWC, MAS, PathInt, LwF, LwM and DMC (green). The rehearsal approaches are iCaRL, EEIL and RWalk (blue). | ||
The bias-correction approaches are IL2M, BiC and LUCIR (orange). | ||
|
||
![alt text](../docs/_static/cil_survey_approaches.png "Survey approaches") | ||
|
||
More approaches will be included in the future. To learn more about them refer to the readme in | ||
[src/approaches](approaches). | ||
|
||
## Datasets | ||
To learn about the dataset management refer to the readme in [src/datasets](datasets). | ||
|
||
## Networks | ||
To learn about the different torchvision and custom networks refer to the readme in [src/networks](networks). | ||
|
||
## GridSearch | ||
We implement the option to use a realistic grid search for hyperparameters which only takes into account the task at | ||
hand, without access to previous or future information not available in the incremental learning scenario. It | ||
corresponds to the one introduced in _**Class-incremental learning: survey and performance evaluation**_. The GridSearch | ||
can be applied by using: | ||
|
||
* `--gridsearch-tasks`: number of tasks to apply GridSearch (-1: all tasks) (default=-1) | ||
|
||
which we recommend to set to the total number of tasks of the experiment for a more realistic setting of the correct | ||
learning rate and possible forgetting-intransigence trade-off. However, since this has a considerable extra | ||
computational cost, it can also be set to the first 3 tasks, which would fix those hyperparameters for the remaining | ||
tasks. Other GridSearch options include: | ||
|
||
* `--gridsearch-config`: configuration file for GridSearch options (default='gridsearch_config') [[more](gridsearch_config.py)] | ||
* `--gridsearch-acc-drop-thr`: GridSearch accuracy drop threshold (default=0.2) | ||
* `--gridsearch-hparam-decay`: GridSearch hyperparameter decay (default=0.5) | ||
* `--gridsearch-max-num-searches`: GridSearch maximum number of hyperparameter search (default=7) | ||
|
||
## Utils | ||
We have some utility functions added into `utils.py`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import importlib | ||
from copy import deepcopy | ||
from argparse import ArgumentParser | ||
|
||
import utils | ||
|
||
|
||
class GridSearch: | ||
"""Basic class for implementing hyperparameter grid search""" | ||
|
||
def __init__(self, appr_ft, seed, gs_config='gridsearch_config', acc_drop_thr=0.2, hparam_decay=0.5, | ||
max_num_searches=7): | ||
self.seed = seed | ||
GridSearchConfig = getattr(importlib.import_module(name=gs_config), 'GridSearchConfig') | ||
self.appr_ft = appr_ft | ||
self.gs_config = GridSearchConfig() | ||
self.acc_drop_thr = acc_drop_thr | ||
self.hparam_decay = hparam_decay | ||
self.max_num_searches = max_num_searches | ||
self.lr_first = 1.0 | ||
|
||
@staticmethod | ||
def extra_parser(args): | ||
"""Returns a parser containing the GridSearch specific parameters""" | ||
parser = ArgumentParser() | ||
# Configuration file with a GridSearchConfig class with all necessary args | ||
parser.add_argument('--gridsearch-config', type=str, default='gridsearch_config', required=False, | ||
help='Configuration file for GridSearch options (default=%(default)s)') | ||
# Accuracy threshold drop below which the search stops for that phase | ||
parser.add_argument('--gridsearch-acc-drop-thr', default=0.2, type=float, required=False, | ||
help='GridSearch accuracy drop threshold (default=%(default)f)') | ||
# Value at which hyperparameters decay | ||
parser.add_argument('--gridsearch-hparam-decay', default=0.5, type=float, required=False, | ||
help='GridSearch hyperparameter decay (default=%(default)f)') | ||
# Maximum number of searched before the search stops for that phase | ||
parser.add_argument('--gridsearch-max-num-searches', default=7, type=int, required=False, | ||
help='GridSearch maximum number of hyperparameter search (default=%(default)f)') | ||
return parser.parse_known_args(args) | ||
|
||
def search_lr(self, model, t, trn_loader, val_loader): | ||
"""Search for accuracy and best LR on finetuning""" | ||
best_ft_acc = 0.0 | ||
best_ft_lr = 0.0 | ||
|
||
# Get general parameters and fix the ones with only one value | ||
gen_params = self.gs_config.get_params('general') | ||
for k, v in gen_params.items(): | ||
if not isinstance(v, list): | ||
setattr(self.appr_ft, k, v) | ||
if t > 0: | ||
# LR for search are 'lr_searches' largest LR below 'lr_first' | ||
list_lr = [lr for lr in gen_params['lr'] if lr < self.lr_first][:gen_params['lr_searches'][0]] | ||
else: | ||
# For first task, try larger LR range | ||
list_lr = gen_params['lr_first'] | ||
|
||
# Iterate through the other variable parameters | ||
for curr_lr in list_lr: | ||
utils.seed_everything(seed=self.seed) | ||
self.appr_ft.model = deepcopy(model) | ||
self.appr_ft.lr = curr_lr | ||
self.appr_ft.train(t, trn_loader, val_loader) | ||
_, ft_acc_taw, _ = self.appr_ft.eval(t, val_loader) | ||
if ft_acc_taw > best_ft_acc: | ||
best_ft_acc = ft_acc_taw | ||
best_ft_lr = curr_lr | ||
print('Current best LR: ' + str(best_ft_lr)) | ||
self.gs_config.current_lr = best_ft_lr | ||
print('Current best acc: {:5.1f}'.format(best_ft_acc * 100)) | ||
# After first task, keep LR used | ||
if t == 0: | ||
self.lr_first = best_ft_lr | ||
|
||
return best_ft_acc, best_ft_lr | ||
|
||
def search_tradeoff(self, appr_name, appr, t, trn_loader, val_loader, best_ft_acc): | ||
"""Search for less-forgetting tradeoff with minimum accuracy loss""" | ||
best_tradeoff = None | ||
tradeoff_name = None | ||
|
||
# Get general parameters and fix all the ones that have only one option | ||
appr_params = self.gs_config.get_params(appr_name) | ||
for k, v in appr_params.items(): | ||
if isinstance(v, list): | ||
# get tradeoff name as the only one with multiple values | ||
tradeoff_name = k | ||
else: | ||
# Any other hyperparameters are fixed | ||
setattr(appr, k, v) | ||
|
||
# If there is no tradeoff, no need to gridsearch more | ||
if tradeoff_name is not None and t > 0: | ||
# get starting value for trade-off hyperparameter | ||
best_tradeoff = appr_params[tradeoff_name][0] | ||
# iterate through decreasing trade-off values -- limit to `max_num_searches` searches | ||
num_searches = 0 | ||
while num_searches < self.max_num_searches: | ||
utils.seed_everything(seed=self.seed) | ||
# Make deepcopy of the appr without duplicating the logger | ||
appr_gs = type(appr)(deepcopy(appr.model), appr.device, exemplars_dataset=appr.exemplars_dataset) | ||
for attr, value in vars(appr).items(): | ||
if attr == 'logger': | ||
setattr(appr_gs, attr, value) | ||
else: | ||
setattr(appr_gs, attr, deepcopy(value)) | ||
|
||
# update tradeoff value | ||
setattr(appr_gs, tradeoff_name, best_tradeoff) | ||
# train this iteration | ||
appr_gs.train(t, trn_loader, val_loader) | ||
_, curr_acc, _ = appr_gs.eval(t, val_loader) | ||
print('Current acc: ' + str(curr_acc) + ' for ' + tradeoff_name + '=' + str(best_tradeoff)) | ||
# Check if accuracy is within acceptable threshold drop | ||
if curr_acc < ((1 - self.acc_drop_thr) * best_ft_acc): | ||
best_tradeoff = best_tradeoff * self.hparam_decay | ||
else: | ||
break | ||
num_searches += 1 | ||
else: | ||
print('There is no trade-off to gridsearch.') | ||
|
||
return best_tradeoff, tradeoff_name |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
class GridSearchConfig(): | ||
def __init__(self): | ||
self.params = { | ||
'general': { | ||
'lr_first': [5e-1, 1e-1, 5e-2], | ||
'lr': [1e-1, 5e-2, 1e-2, 5e-3, 1e-3], | ||
'lr_searches': [3], | ||
'lr_min': 1e-4, | ||
'lr_factor': 3, | ||
'lr_patience': 10, | ||
'clipping': 10000, | ||
'momentum': 0.9, | ||
'wd': 0.0002 | ||
}, | ||
'finetuning': { | ||
}, | ||
'freezing': { | ||
}, | ||
'joint': { | ||
}, | ||
'lwf': { | ||
'lamb': [10], | ||
'T': 2 | ||
}, | ||
'icarl': { | ||
'lamb': [4] | ||
}, | ||
'dmc': { | ||
'aux_dataset': 'imagenet_32_reduced', | ||
'aux_batch_size': 128 | ||
}, | ||
'il2m': { | ||
}, | ||
'eeil': { | ||
'lamb': [10], | ||
'T': 2, | ||
'lr_finetuning_factor': 0.1, | ||
'nepochs_finetuning': 40, | ||
'noise_grad': False | ||
}, | ||
'bic': { | ||
'T': 2, | ||
'val_percentage': 0.1, | ||
'bias_epochs': 200 | ||
}, | ||
'lucir': { | ||
'lamda_base': [10], | ||
'lamda_mr': 1.0, | ||
'dist': 0.5, | ||
'K': 2 | ||
}, | ||
'lwm': { | ||
'beta': [2], | ||
'gamma': 1.0 | ||
}, | ||
'ewc': { | ||
'lamb': [10000] | ||
}, | ||
'mas': { | ||
'lamb': [400] | ||
}, | ||
'path_integral': { | ||
'lamb': [10], | ||
}, | ||
'r_walk': { | ||
'lamb': [20], | ||
}, | ||
} | ||
self.current_lr = self.params['general']['lr'][0] | ||
self.current_tradeoff = 0 | ||
|
||
def get_params(self, approach): | ||
return self.params[approach] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import torch | ||
import matplotlib | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
matplotlib.use('Agg') | ||
|
||
|
||
def last_layer_analysis(heads, task, taskcla, y_lim=False, sort_weights=False): | ||
"""Plot last layer weight and bias analysis""" | ||
print('Plotting last layer analysis...') | ||
num_classes = sum([x for (_, x) in taskcla]) | ||
weights, biases, indexes = [], [], [] | ||
class_id = 0 | ||
with torch.no_grad(): | ||
for t in range(task + 1): | ||
n_classes_t = taskcla[t][1] | ||
indexes.append(np.arange(class_id, class_id + n_classes_t)) | ||
if type(heads) == torch.nn.Linear: # Single head | ||
biases.append(heads.bias[class_id: class_id + n_classes_t].detach().cpu().numpy()) | ||
weights.append((heads.weight[class_id: class_id + n_classes_t] ** 2).sum(1).sqrt().detach().cpu().numpy()) | ||
else: # Multi-head | ||
weights.append((heads[t].weight ** 2).sum(1).sqrt().detach().cpu().numpy()) | ||
if type(heads[t]) == torch.nn.Linear: | ||
biases.append(heads[t].bias.detach().cpu().numpy()) | ||
else: | ||
biases.append(np.zeros(weights[-1].shape)) # For LUCIR | ||
class_id += n_classes_t | ||
|
||
# Figure weights | ||
f_weights = plt.figure(dpi=300) | ||
ax = f_weights.subplots(nrows=1, ncols=1) | ||
for i, (x, y) in enumerate(zip(indexes, weights), 0): | ||
if sort_weights: | ||
ax.bar(x, sorted(y, reverse=True), label="Task {}".format(i)) | ||
else: | ||
ax.bar(x, y, label="Task {}".format(i)) | ||
ax.set_xlabel("Classes", fontsize=11, fontfamily='serif') | ||
ax.set_ylabel("Weights L2-norm", fontsize=11, fontfamily='serif') | ||
if num_classes is not None: | ||
ax.set_xlim(0, num_classes) | ||
if y_lim: | ||
ax.set_ylim(0, 5) | ||
ax.legend(loc='upper left', fontsize='11') #, fontfamily='serif') | ||
|
||
# Figure biases | ||
f_biases = plt.figure(dpi=300) | ||
ax = f_biases.subplots(nrows=1, ncols=1) | ||
for i, (x, y) in enumerate(zip(indexes, biases), 0): | ||
if sort_weights: | ||
ax.bar(x, sorted(y, reverse=True), label="Task {}".format(i)) | ||
else: | ||
ax.bar(x, y, label="Task {}".format(i)) | ||
ax.set_xlabel("Classes", fontsize=11, fontfamily='serif') | ||
ax.set_ylabel("Bias values", fontsize=11, fontfamily='serif') | ||
if num_classes is not None: | ||
ax.set_xlim(0, num_classes) | ||
if y_lim: | ||
ax.set_ylim(-1.0, 1.0) | ||
ax.legend(loc='upper left', fontsize='11') #, fontfamily='serif') | ||
|
||
return f_weights, f_biases |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# Loggers | ||
|
||
We include a disk logger, which logs into files and folders in the disk. We also provide a tensorboard logger which | ||
provides a faster way of analysing a training process without need of further development. They can be specified with | ||
`--log` followed by `disk`, `tensorboard` or both. Custom loggers can be defined by inheriting the `ExperimentLogger` | ||
in [exp_logger.py](exp_logger.py). | ||
|
||
When enabled, both loggers will output everything in the path `[RESULTS_PATH]/[DATASETS]_[APPROACH]_[EXP_NAME]` or | ||
`[RESULTS_PATH]/[DATASETS]_[APPROACH]` if `--exp-name` is not set. | ||
|
||
## Disk logger | ||
The disk logger outputs the following file and folder structure: | ||
- **figures/**: folder where generated figures are logged. | ||
- **models/**: folder where model weight checkpoints are saved. | ||
- **results/**: folder containing the results. | ||
- **acc_tag**: task-agnostic accuracy table. | ||
- **acc_taw**: task-aware accuracy table. | ||
- **avg_acc_tag**: task-agnostic average accuracies. | ||
- **avg_acc_taw**: task-agnostic average accuracies. | ||
- **forg_tag**: task-agnostic forgetting table. | ||
- **forg_taw**: task-aware forgetting table. | ||
- **wavg_acc_tag**: task-agnostic average accuracies weighted according to the number of classes of each task. | ||
- **wavg_acc_taw**: task-aware average accuracies weighted according to the number of classes of each task. | ||
- **raw_log**: json file containing all the logged metrics easily read by many tools (e.g. `pandas`). | ||
- stdout: a copy from the standard output of the terminal. | ||
- stderr: a copy from the error output of the terminal. | ||
|
||
## TensorBoard logger | ||
The tensorboard logger outputs analogous metrics to the disk logger separated into different tabs according to the task | ||
and different graphs according to the data splits. | ||
|
||
Screenshot for a 10 task experiment, showing the last task plots: | ||
<p align="center"> | ||
<img src="/docs/_static/tb2.png" alt="Tensorboard Screenshot" width="920"/> | ||
</p> |
Oops, something went wrong.