Generalization properties of learning algorithms and Sharpness Aware Minimization based on Minimum Sharpness
Description: One of the biggest challenges in deep learning is understanding generalization. Sharpness is one indicator of generalization that performs well in practice, and sharpness-aware minimization (SAM) is a recent state-of-the-art technique based on simultaneously minimizing both the loss and the sharpness. In this work, we study the recently introduced notion of minimum sharpness and investigate its correlation with the generalization gap across many different optimizers and SAM. Finally, we tackle the question of adaptivity of learning algorithms, which also has an impact on generalization, and investigate how the choice of optimizer influences sharpness.
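To make the SAM objective concrete: in each step, SAM first ascends to an adversarially perturbed copy of the weights within a small L2 ball and then descends using the gradient taken at that perturbed point. The snippet below is a minimal PyTorch sketch of this standard two-step update (Foret et al., 2021); the helper name `sam_step` and the default `rho` are illustrative assumptions, and the repository itself relies on the external SAM implementation collected in `/optimizers`.

```python
import torch

def sam_step(model, loss_fn, inputs, targets, base_optimizer, rho=0.05):
    # Illustrative sketch of the standard two-step SAM update; the actual
    # implementation used in this repository lives in /optimizers.
    # 1) gradient at the current weights w
    base_optimizer.zero_grad()
    loss_fn(model(inputs), targets).backward()

    # 2) move to the (approximately) worst-case weights w + eps,
    #    with eps = rho * grad / ||grad||
    params = [p for p in model.parameters() if p.grad is not None]
    grad_norm = torch.norm(torch.stack([p.grad.norm(p=2) for p in params]), p=2)
    with torch.no_grad():
        eps = [rho * p.grad / (grad_norm + 1e-12) for p in params]
        for p, e in zip(params, eps):
            p.add_(e)

    # 3) gradient taken at the perturbed weights drives the actual update
    base_optimizer.zero_grad()
    loss_fn(model(inputs), targets).backward()

    # 4) undo the perturbation and let the base optimizer take its step
    with torch.no_grad():
        for p, e in zip(params, eps):
            p.sub_(e)
    base_optimizer.step()
```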
- The folder `/checkpoints` contains results from the runs you perform within the repository once you fork it.
- The folder `/checkpoints_test` contains some precomputed checkpoints for the model illustrated in the `TrainingSample.ipynb` notebook. The checkpoints folder is structured as `/DATASET/MODEL_ARCHITECTURE/epochX`:
  - `DATASET` is `FashionMNIST` or `CIFAR10`
  - `MODEL_ARCHITECTURE` is `SimpleBatch`, `MiddleBatch` or `ComplexBatch`
  - `epochX`, with `X` one of 50, 100, 150, 200: we train all models up to 200 epochs and save the checkpoints every 50 epochs
  - `converged`: whenever the model converges (the loss drops below the set tolerance), we save the checkpoints again
- The folder `/data` should be empty by default and is populated with data when training the models.
- The folder `/results` contains .csv files with the results for each of the datasets.
- The folder `/optimizers` contains PyTorch implementations of AdaBound, AdaShift and SAM, collected from external sources.
- The folder `/sharpness` contains the approximate calculation of the Hessian and sharpness (see the illustrative sketch after this list).
- The notebooks `TrainingSample.ipynb` and `DataAnalysis.ipynb` illustrate our work and are presented below.
- The files within the repository are:
  - `models.py` - contains the architectures of the models we considered
  - `main.py` - runs the training and computations for a given configuration; a configuration is given by dataset, model architecture and optimizer
  - `helpers.py` - various utilities used for training, testing, computation and data preprocessing
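As a rough, generic illustration of what a Hessian-based sharpness proxy looks like, the sketch below estimates the largest eigenvalue of the loss Hessian by power iteration on Hessian-vector products. The function name, the number of iterations and the use of the top eigenvalue as a proxy are assumptions made for this example; the minimum-sharpness computation actually implemented in `/sharpness` may differ.

```python
import torch

def top_hessian_eigenvalue(loss, model, iters=20):
    # Power-iteration sketch for the largest Hessian eigenvalue of `loss`
    # with respect to the model parameters, using Hessian-vector products.
    params = [p for p in model.parameters() if p.requires_grad]
    grads = torch.autograd.grad(loss, params, create_graph=True)

    # random starting direction, normalized
    v = [torch.randn_like(p) for p in params]
    v_norm = torch.sqrt(sum((x * x).sum() for x in v))
    v = [x / v_norm for x in v]

    eigenvalue = None
    for _ in range(iters):
        # Hessian-vector product Hv via a second backward pass
        gv = sum((g * x).sum() for g, x in zip(grads, v))
        hv = torch.autograd.grad(gv, params, retain_graph=True)
        # Rayleigh quotient v^T H v with the current (unit-norm) direction v
        eigenvalue = sum((h * x).sum() for h, x in zip(hv, v)).item()
        hv_norm = torch.sqrt(sum((h * h).sum() for h in hv))
        v = [h / (hv_norm + 1e-12) for h in hv]
    return eigenvalue
```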
An installation of Python is required. The needed libraries are listed in `requirements.txt`; to install them, run `pip install -r requirements.txt` or `pip3 install -r requirements.txt` (for Python 3).
- To explore our work, we encourage you to look through our notebooks:
  - `TrainingSample.ipynb` allows you to train a model and compute sharpness for a given dataset, architecture and optimizer.
  - `DataAnalysis.ipynb` loads all the results from the trainings and prepares the plots. If results are missing, it requires you either to download all the existing checkpoints from training to extract the results, which may take a long time, or to retrain the models. All the existing checkpoints are available at https://drive.google.com/drive/folders/10LuJDXzP6P_xH-z66Kh4KaWPfR1s0-t9?usp=sharing; due to the size limits on GitHub, we have not added them here (>30 GB).
- For running a model for a given configuration, we also offer a runnable Python file:
  - `python main.py train $dataset $model $optimizer $use_sam $load_existing` allows you to train a model:
    - `dataset` should be `CIFAR10` or `FashionMNIST`
    - `model` should be `SimpleBatch`, `MiddleBatch` or `ComplexBatch`
    - `optimizer` should be `SGD`, `PHB`, `Adagrad`, `Adam`, `AdaShift` or `AdaBound`
    - `use_sam` should be 0 (do not use SAM) or 1 (use SAM)
    - `load_existing` is not used here; it can be 0 or 1
  - `python main.py compute_sharpness $dataset $model $optimizer $use_sam $load_existing` allows you to compute sharpness for the given model. All parameters stay the same, except that `load_existing` should be 1 if you have already trained the model and would like to load it from file, and 0 otherwise.
  - `python main.py plot $dataset $model $optimizer $use_sam $load_existing` allows you to visualize the computations.
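For example, composing the parameters documented above, training the `SimpleBatch` model on `CIFAR10` with `Adam` and SAM enabled, and then computing its sharpness from the saved checkpoint, could look like this (the specific combination is only an illustration):

```bash
python main.py train CIFAR10 SimpleBatch Adam 1 0
python main.py compute_sharpness CIFAR10 SimpleBatch Adam 1 1
```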
- To automatically run all the configurations (dataset, optimizer, architecture), we offer shell scripts which can be run as follows:
  - To train all models, run `chmod +x train_all` and then `./train_all`.
  - To compute sharpness for all trained models, run `chmod +x compute_sharpness_all` and then `./compute_sharpness_all.sh`.
- Jana Vuckovic: [email protected]
- Miguel-Angel Sanchez Ndoye: [email protected]
- Irina Bejan: [email protected]