This repository is the official implementation of Compression-aware Continual Learning using SVD.
Check out the arXiv version of our paper:
To install the requirements, create an Anaconda environment, activate it, and install the pip dependencies:

```bash
conda create --name <env> --file anaconda_requirements.txt
conda activate <env>
pip install -r requirements.txt
```
All datasets except notMNIST and miniImageNet are downloaded automatically through `torchvision.datasets`.
- notMNIST is downloaded by default from Adversarial Continual Learning.
- Please download miniImageNet from and unzip `train.pkl` and `test.pkl` into a new folder, `data/mini-imagenet`.
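Before starting an experiment, it can help to confirm that the miniImageNet files sit where the instructions above say they should. The snippet below is a minimal sketch that only checks for `train.pkl` and `test.pkl` under `data/mini-imagenet`; the helper name `missing_miniimagenet_files` is hypothetical and not part of this repository, and it does not open or validate the pickles.

```python
from pathlib import Path

# Files the instructions above ask you to place under data/mini-imagenet
REQUIRED_FILES = ("train.pkl", "test.pkl")


def missing_miniimagenet_files(root="data/mini-imagenet"):
    """Return the required miniImageNet files that are absent under `root`.

    Hypothetical helper: it checks file existence only.
    """
    root_dir = Path(root)
    return [name for name in REQUIRED_FILES if not (root_dir / name).is_file()]


if __name__ == "__main__":
    missing = missing_miniimagenet_files()
    if missing:
        print("Missing from data/mini-imagenet:", ", ".join(missing))
    else:
        print("miniImageNet files found.")
```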
We provide pretrained models for CIFAR-100, miniImageNet, and the 5-sequence dataset. To evaluate a trained model, use:
| Model name | Accuracy | Model size (MB) |
|---|---|---|
| CACL_Final | 86.58% | 8.53 |

| Model name | Accuracy | Model size (MB) |
|---|---|---|
| CACL_Final | 70.10% | 13.03 |

| Model name | Accuracy | Model size (MB) |
|---|---|---|
| CACL_Final | 91.56% | 1.48 |
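As a rough illustration of how an SVD factorization affects storage, the sketch below counts the parameters of a dense m×n weight against a rank-r factorization W ≈ U diag(s) V^T (U is m×r, s has r entries, V is n×r) and converts the count to megabytes. It assumes 32-bit floats and 1 MB = 2^20 bytes; both the storage layout and the MB convention are assumptions, and this is not necessarily how the sizes in the tables above were computed.

```python
BYTES_PER_FLOAT32 = 4
BYTES_PER_MB = 2 ** 20  # assumption: MB as mebibytes; use 10**6 for decimal MB


def dense_params(m, n):
    """Parameters in a dense m x n weight matrix."""
    return m * n


def svd_params(m, n, rank):
    """Parameters in a rank-`rank` factorization: U (m x r), s (r), V (n x r)."""
    if not 0 < rank <= min(m, n):
        raise ValueError("rank must be in [1, min(m, n)]")
    return rank * (m + n + 1)


def size_mb(num_params, bytes_per_param=BYTES_PER_FLOAT32):
    """Storage size in MB for `num_params` stored values."""
    return num_params * bytes_per_param / BYTES_PER_MB


if __name__ == "__main__":
    m, n = 512, 512
    dense_mb = size_mb(dense_params(m, n))
    for r in (512, 128, 32):
        p = svd_params(m, n, r)
        print(f"rank {r:3d}: {p} params, {size_mb(p):.2f} MB (dense: {dense_mb:.2f} MB)")
```

Storing all three factors at full rank roughly doubles the size of a square layer, so the factorization only saves memory once r drops below mn/(m+n+1), which is why removing singular values matters for compression.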
To train the models from scratch, run the following command. The training scripts contain the hyperparameter settings used to obtain the results in the paper:

```bash
python -e <pruning_intensity> \
    --model_name <Net_SVD/vgg16_bn_cifar100_SVD> \
    --model_type <customnet_SVD/vgg_16_bn> \
    --exp_name <exp_name> \
    --first_split_size <classes per task> \
    --other_split_size <classes per task> \
    --train_aug \
    --schedule <lrdropepochs> \
    --batch_size <64/128> \
    --dataset <CIFAR100/miniImageNet/multidataset> \
    --force_out_dim 0 \
    --sparse_wt <sparsity weight> \
    --benchmark <fixatesrandomseed> \
    --rand_split_order \
    --repeat 3
```
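The `--first_split_size`, `--other_split_size`, and `--rand_split_order` flags control how classes are grouped into tasks. The sketch below illustrates that behavior as we read it from the flag names: the first task receives `first_split_size` classes, every later task receives `other_split_size`, and the class order is optionally shuffled with a fixed seed. The function `split_classes` is hypothetical; the actual splitting logic lives in the repository's data loaders and may differ in details such as seeding or how invalid sizes are reported.

```python
import random


def split_classes(num_classes, first_split_size, other_split_size,
                  rand_split_order=False, seed=0):
    """Group class ids into tasks (hypothetical helper).

    Task 0 gets `first_split_size` classes; each later task gets
    `other_split_size`. Raises ValueError if the sizes do not exactly
    cover `num_classes`.
    """
    if first_split_size <= 0 or other_split_size <= 0:
        raise ValueError("split sizes must be positive")
    remaining = num_classes - first_split_size
    if remaining < 0 or remaining % other_split_size != 0:
        raise ValueError("split sizes do not evenly cover num_classes")

    class_order = list(range(num_classes))
    if rand_split_order:
        # A seeded RNG keeps the shuffled class order reproducible
        random.Random(seed).shuffle(class_order)

    tasks = [class_order[:first_split_size]]
    for start in range(first_split_size, num_classes, other_split_size):
        tasks.append(class_order[start:start + other_split_size])
    return tasks


if __name__ == "__main__":
    # For example, 100 classes split into tasks of 10 classes each
    tasks = split_classes(100, 10, 10, rand_split_order=True, seed=0)
    print(len(tasks), [len(t) for t in tasks])
```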
This repository is built on top of GT-RIPL/Continual-Learning-Benchmark, which includes baseline results for recent continual learning algorithms.