This repository contains the code used to produce BlockSwap (paper).
For a network composed of N stacked blocks, BlockSwap samples, uniformly at random, lists of N candidate convolution alternatives that fit a given parameter budget. It ranks the samples using Fisher potential as a proxy for trained accuracy and then returns the best one.
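The sample-and-rank procedure described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the code in this repository: the block names, their parameter costs, the budget tolerance, and `toy_score` (a stand-in for Fisher potential) are all hypothetical.

```python
import random

# Hypothetical per-block parameter costs, for illustration only.
# The real block substitutions are defined in models/blocks.py.
BLOCK_COSTS = {"standard": 100, "grouped": 50, "bottleneck": 25, "depthwise": 12}


def sample_config(n_blocks, rng):
    """Uniformly pick one block type for each of the N positions."""
    names = sorted(BLOCK_COSTS)
    return [rng.choice(names) for _ in range(n_blocks)]


def param_count(config):
    """Total parameter cost of a configuration."""
    return sum(BLOCK_COSTS[name] for name in config)


def search(n_blocks, budget, score_fn, n_samples=200, tolerance=0.1, seed=0):
    """Sample configurations, keep those near the budget, return the best-scoring one."""
    rng = random.Random(seed)
    best, best_score = None, float("-inf")
    for _ in range(n_samples):
        config = sample_config(n_blocks, rng)
        params = param_count(config)
        # Discard samples that exceed the budget or fall far below it.
        if not (budget * (1 - tolerance) <= params <= budget):
            continue
        score = score_fn(config)
        if score > best_score:
            best, best_score = config, score
    return best, best_score


def toy_score(config):
    """Placeholder ranking signal; BlockSwap uses Fisher potential here."""
    return sum((i + 1) * BLOCK_COSTS[name] ** 0.5 for i, name in enumerate(config))


best, best_score = search(n_blocks=6, budget=300, score_fn=toy_score)
print(best, param_count(best), best_score)
```

Because the ranking signal is passed in as `score_fn`, the same loop works with any proxy that can be evaluated on an untrained network.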
Install the requirements via anaconda:
conda env create -f environment.yml
- checkpoints/ is used to save trained models
- genotypes/ is used to store .csv files that contain network configurations chosen by BlockSwap. We have also included the exact models from the paper for reference.
- models/ contains PyTorch definitions for all of the models and block types that we used
  - models/blocks.py is where all of the block substitutions live
- utils.py contains useful operations that are used throughout the repository. It also includes random configuration sampling code.
  - one_shot_fisher is the function used to get the Fisher potential of a given network
- model_generator.py ranks random configurations at a given parameter goal
- train.py can train your selected network
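To give a rough idea of what a Fisher-style score measures, here is a small pure-Python sketch. It assumes the per-channel definition from the paper: for each example, multiply the activation map by its gradient element-wise, sum over spatial positions, square, and average over the N examples with a factor of 1/(2N). A block's score is then the sum over its channels. The function names and nested-list layout are hypothetical, and this is not the `one_shot_fisher` implementation in utils.py, which operates on PyTorch tensors.

```python
def channel_fisher(acts, grads):
    """Fisher-style score for one channel.

    acts, grads: one 2D map (list of rows) per example, same shapes.
    """
    n = len(acts)
    total = 0.0
    for a_map, g_map in zip(acts, grads):
        # Sum of activation * gradient over all spatial positions.
        inner = sum(
            a * g
            for a_row, g_row in zip(a_map, g_map)
            for a, g in zip(a_row, g_row)
        )
        total += inner ** 2
    return total / (2 * n)


def block_fisher(acts_by_channel, grads_by_channel):
    """Sum the per-channel scores of a block."""
    return sum(
        channel_fisher(a, g) for a, g in zip(acts_by_channel, grads_by_channel)
    )


# Two examples, one channel, 1x2 feature maps (made-up numbers).
acts = [[[1.0, 2.0]], [[0.0, 1.0]]]
grads = [[[0.5, 0.5]], [[2.0, -1.0]]]
print(channel_fisher(acts, grads))  # 0.8125
```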
The general outline for using this code is as follows:
- Train your original network (the network you would like to compress, also called the teacher)
- Generate and rank possible student networks for a given parameter budget
- Train the highest ranking student network
These steps are illustrated below:
1. Train your original network
python train.py teacher -t wrn_40_2 --wrn_depth 40 --wrn_width 2 --data_loc='<path-to-data>' --GPU 0
2. Generate and rank possible student networks for a given parameter budget
Then you can generate student networks for a parameter goal of your choice:
python model_generator.py --data_loc='<path-to-data>' --param_goal $p
This will save a .csv file containing the generated architecture.
3. Train the highest ranking student network
Train the network using the following command:
python train.py student -t wrn_40_2 -s wrn_40_2_<genotype-num> --wrn_depth 40 --wrn_width 2 --data_loc='<path-to-data>' --GPU 0 --from_genotype './genotypes/<genotype-num>.csv'
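If you prefer to script the three steps, the commands above can be assembled in Python. This sketch only builds and prints the command lines, using the flags shown in this README; the helper names, the data path, the parameter goal of 800000, and the genotype number 0 are hypothetical placeholders to replace with your own values.

```python
import shlex


def teacher_cmd(data_loc, gpu=0):
    """Step 1: train the original (teacher) network."""
    return [
        "python", "train.py", "teacher", "-t", "wrn_40_2",
        "--wrn_depth", "40", "--wrn_width", "2",
        f"--data_loc={data_loc}", "--GPU", str(gpu),
    ]


def generator_cmd(data_loc, param_goal):
    """Step 2: generate and rank student networks for a parameter goal."""
    return [
        "python", "model_generator.py",
        f"--data_loc={data_loc}", "--param_goal", str(param_goal),
    ]


def student_cmd(data_loc, genotype_num, gpu=0):
    """Step 3: train the student described by genotypes/<genotype-num>.csv."""
    return [
        "python", "train.py", "student", "-t", "wrn_40_2",
        "-s", f"wrn_40_2_{genotype_num}",
        "--wrn_depth", "40", "--wrn_width", "2",
        f"--data_loc={data_loc}", "--GPU", str(gpu),
        "--from_genotype", f"./genotypes/{genotype_num}.csv",
    ]


data_loc = "/path/to/data"  # placeholder
commands = [
    teacher_cmd(data_loc),
    generator_cmd(data_loc, param_goal=800000),  # placeholder budget
    student_cmd(data_loc, genotype_num=0),       # placeholder genotype number
]
for cmd in commands:
    print(shlex.join(cmd))
```

To actually run a step, pass the list to `subprocess.run(cmd, check=True)` from the repository root.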
The following repos provided the basis and inspiration for this work:
https://github.com/szagoruyko/attention-transfer
https://github.com/kuangliu/pytorch-cifar
https://github.com/xternalz/WideResNet-pytorch
https://github.com/ShichenLiu/CondenseNet
If you find this work helpful, please consider citing us:
@inproceedings{
Turner2020BlockSwap:,
title={BlockSwap: Fisher-guided Block Substitution for Network Compression on a Budget},
author={Jack Turner and Elliot J. Crowley and Michael O'Boyle and Amos Storkey and Gavin Gray},
booktitle={International Conference on Learning Representations},
year={2020},
url={https://openreview.net/forum?id=SklkDkSFPB}
}