Skip to content

Latest commit

 

History

History
37 lines (29 loc) · 1.77 KB

README.md

File metadata and controls

37 lines (29 loc) · 1.77 KB

Meta-learning Symmetries by Reparameterization (MSR)

Code and weights corresponding to Meta-learning Symmetries by Reparameterization, found at this arXiv link.

Installation and requirements

All experiments were run in a conda environment with python 3.7.5, using pytorch. The conda environment we used is exported in environment.yml, though it likely contains more packages than are strictly necessary.

Synthetic experiments

Generate data

The file generate_synthetic_data.py contains the data generation code. You must specify the name of the problem we wish to generate data for. Options are:

  • rank1: 1-D data generated by a convolution layer
  • rank2: 1-D data generated by a locally connected layer, with rank 2 weight factorization.
  • rank5: 1-D data generated by a locally connected layer, with rank 5 weight factorization.
  • 2d_rot8: 2-D data that is equivariant to 45-degree rotations.
  • 2d_rot8_flip: 2-D data that is equivariant to 45-degree rotations and flips.

Example:

python generate_synthetic_data.py --problem rank1

Train

The file train_synthetic.py contains training and evaluation code. Specifying the argument model controls whether to use MSR, plain MAML, or something else.

Some examples are given below:

python train_synthetic.py --problem rank1 --model share_fc  # MSR+FC on rank1 problem.
python train_synthetic.py --problem rank1 --model fc  # MAML+FC on rank1 problem.
python train_synthetic.py --problem rank1 --model conv  # MAML+Conv model on rank1 problem.
python train_synthetic.py --problem 2d_rot8 --model share_conv  # MSR+Conv on 2d_rot8 problem.
python train_synthetic.py --problem 2d_rot8 --model conv  # MAML+Conv on 2d_rot8 problem.

Augmented-(Omniglot/Miniimagenet)

WIP