Official PyTorch implementation of the CVPR 2022 paper Self-Taught Metric Learning without Labels (STML). A standard embedding network trained with STML achieves state-of-the-art performance on unsupervised metric learning and sometimes even beats supervised learning models. This repository provides the source code for unsupervised metric learning experiments on three datasets (CUB-200-2011, Cars-196, Stanford Online Products).
- [23.04.22] Added code for training and testing an embedding network with SSL methods (MoCo, BYOL, MeanShift).
- [23.04.22] Fixed the "same device as the indexed tensor" bug in evaluation.
- Contextualized semantic similarity between a pair of data is estimated on the embedding space of the teacher network.
- The semantic similarity is then used as a pseudo label, and the student network is optimized by a relaxed contrastive loss with KL divergence.
- The teacher network is updated by an exponential moving average of the student.
- Our model with 128 embedding dimensions outperforms all previous methods that use higher embedding dimensions, and sometimes surpasses supervised learning methods.
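The three steps above can be illustrated with a small pure-Python sketch. It is not the repository's code and rests on our own assumptions: pairwise similarity is a Gaussian kernel exp(-||e_i - e_j||^2 / sigma) on teacher embeddings, the contextual term is a simplified Jaccard overlap of k-reciprocal neighbor sets, and the pseudo label is the plain average of the two. We assume sigma and k play the roles of the --sigma and --num_neighbors flags in the training commands. The hypothetical ema_update helper shows the teacher update on flat lists of numbers standing in for tensors; its momentum value is illustrative only.

```python
import math


def pairwise_similarity(embeddings, sigma=3.0):
    """Gaussian-kernel similarity exp(-||e_i - e_j||^2 / sigma) for every pair (assumed form)."""
    n = len(embeddings)
    return [
        [math.exp(-sum((a - b) ** 2 for a, b in zip(embeddings[i], embeddings[j])) / sigma)
         for j in range(n)]
        for i in range(n)
    ]


def contextual_similarity(sims, k=5):
    """Simplified context term: Jaccard overlap of k-reciprocal neighbor sets (an assumption)."""
    n = len(sims)
    # k nearest neighbors of each sample by similarity (a sample counts as its own neighbor)
    nearest = [set(sorted(range(n), key=lambda j: -sims[i][j])[:k]) for i in range(n)]
    # keep j as a neighbor of i only if i is also among j's k nearest neighbors
    reciprocal = [{j for j in nearest[i] if i in nearest[j]} for i in range(n)]
    context = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            union = reciprocal[i] | reciprocal[j]
            if union:
                context[i][j] = len(reciprocal[i] & reciprocal[j]) / len(union)
    return context


def pseudo_labels(teacher_embeddings, sigma=3.0, k=5):
    """Average of pairwise and contextual similarity, used as soft pseudo labels in [0, 1]."""
    w_pair = pairwise_similarity(teacher_embeddings, sigma)
    w_ctx = contextual_similarity(w_pair, k)
    n = len(w_pair)
    return [[0.5 * (w_pair[i][j] + w_ctx[i][j]) for j in range(n)] for i in range(n)]


def ema_update(teacher_params, student_params, momentum=0.999):
    """Teacher <- momentum * teacher + (1 - momentum) * student; the momentum value is illustrative."""
    return [momentum * t + (1.0 - momentum) * s for t, s in zip(teacher_params, student_params)]


# Two well-separated pairs of 1-D teacher embeddings.
w = pseudo_labels([[0.0], [0.1], [5.0], [5.1]], sigma=3.0, k=2)
print(round(w[0][1], 3), round(w[0][2], 3))
```

Samples from the same tight pair receive a pseudo label close to 1 and samples from different pairs one close to 0. In a real training loop the same operations would run on PyTorch tensors, with the EMA step applied to every teacher parameter under torch.no_grad().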
- Python3
- PyTorch (> 1.0)
- NumPy
- tqdm
- wandb
- AdamP
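Before training, a quick way to confirm that the listed packages are importable is a check like the one below. It is a hypothetical helper, not part of the repository; it assumes the import names torch, numpy, tqdm, wandb and adamp for the requirements above, and it does not verify the PyTorch version.

```python
import importlib.util

# Assumed import names for the listed requirements (AdamP is installed as the "adamp" package).
REQUIRED_MODULES = ("torch", "numpy", "tqdm", "wandb", "adamp")


def missing_modules(names=REQUIRED_MODULES):
    """Return the names from `names` that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    print("All requirements found." if not missing else "Missing: " + ", ".join(missing))
```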
Download three public benchmarks for deep metric learning
- CUB-200-2011
- Cars-196 (images and annotations)
- Stanford Online Products
Extract the tgz or zip files into ./data/. (For Cars-196, put the files in ./data/cars196 instead.)
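Extraction can also be scripted with the standard library. The helper below is hypothetical and the archive names in the usage comments are placeholders; it only illustrates the layout described above, where every dataset is extracted under ./data/ and Cars-196 goes to its own ./data/cars196 subfolder.

```python
import tarfile
import zipfile
from pathlib import Path


def extract_archive(archive, dest="./data"):
    """Extract a .zip or .tgz/.tar.gz archive into dest, creating dest if needed."""
    dest = Path(dest)
    dest.mkdir(parents=True, exist_ok=True)
    if zipfile.is_zipfile(archive):
        with zipfile.ZipFile(archive) as zf:
            zf.extractall(dest)
    elif tarfile.is_tarfile(archive):
        with tarfile.open(archive) as tf:  # compression is auto-detected
            # use the safer "data" extraction filter on Python versions that provide it
            if hasattr(tarfile, "data_filter"):
                tf.extractall(dest, filter="data")
            else:
                tf.extractall(dest)
    else:
        raise ValueError(f"not a zip or tar archive: {archive}")
    return dest


# Usage with placeholder archive names:
# extract_archive("cub_archive.tgz")                     # CUB-200-2011 -> ./data/
# extract_archive("sop_archive.zip")                     # Stanford Online Products -> ./data/
# extract_archive("cars_images.tgz", "./data/cars196")   # Cars-196 -> ./data/cars196
```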
- Train an embedding network with GoogLeNet (d=512) using STML on CUB-200-2011
python3 code/main.py --gpu-id 0 \
--model googlenet \
--embedding_size 512 \
--optimizer adamp \
--lr 1e-4 \
--dataset cub \
--view 2 \
--sigma 3 \
--delta 1 \
--num_neighbors 5
- Train an embedding network with BN-Inception (d=512) using STML on CUB-200-2011
python3 code/main.py --gpu-id 0 \
--model bn_inception \
--embedding_size 512 \
--optimizer adamp \
--lr 1e-4 \
--dataset cub \
--view 2 \
--sigma 3 \
--delta 1 \
--num_neighbors 5 \
--bn-freeze 1
- Train an embedding network with GoogLeNet (d=512) using STML on Cars-196
python3 code/main.py --gpu-id 0 \
--model googlenet \
--embedding_size 512 \
--optimizer adamp \
--lr 1e-4 \
--dataset cars \
--view 2 \
--sigma 3 \
--delta 1 \
--num_neighbors 5
- Train an embedding network with BN-Inception (d=512) using STML on Cars-196
python3 code/main.py --gpu-id 0 \
--model bn_inception \
--embedding_size 512 \
--optimizer adamp \
--lr 1e-4 \
--dataset cars \
--view 2 \
--sigma 3 \
--delta 1 \
--num_neighbors 5 \
--bn-freeze 1
- Train an embedding network with GoogLeNet (d=512) using STML on Stanford Online Products
python3 code/main.py --gpu-id 0 \
--model googlenet \
--embedding_size 512 \
--optimizer adamp \
--lr 1e-4 \
--dataset SOP \
--view 2 \
--sigma 3 \
--delta 0.9 \
--num_neighbors 2 \
--momentum 0.9 \
--weight-decay 1e-2 \
--emb-lr 1e-2
- Train an embedding network with BN-Inception (d=512) using STML on Stanford Online Products
python3 code/main.py --gpu-id 0 \
--model bn_inception \
--embedding_size 512 \
--optimizer adamp \
--lr 1e-4 \
--dataset SOP \
--view 2 \
--sigma 3 \
--delta 0.9 \
--num_neighbors 2 \
--momentum 0.9 \
--weight-decay 1e-2 \
--emb-lr 1e-2 \
--bn_freeze 1
- Train an embedding network with ResNet18 (d=128) using STML on Stanford Online Products
python3 code/main.py --gpu-id 0 \
--model resnet18 \
--embedding_size 128 \
--optimizer adamp \
--lr 5e-4 \
--dataset SOP \
--view 2 \
--sigma 3 \
--delta 0.9 \
--num_neighbors 2 \
--momentum 0.9 \
--pretrained false \
--weight-decay 1e-2 \
--batch-size 120 \
--epoch 180 \
--fix_lr true
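The --delta value in the commands above plausibly corresponds to the margin of the relaxed contrastive loss; that mapping is our assumption. The sketch below uses the form from the relaxed-contrastive-loss literature: each pair is pulled together with weight w_ij (its soft pseudo label) and pushed apart up to the margin delta with weight 1 - w_ij, on distances divided by each anchor's mean distance to the other samples. It is pure Python for readability, omits the KL-divergence term, and should not be read as the repository's loss code.

```python
import math


def relaxed_contrastive_loss(student, weights, delta=1.0, eps=1e-12):
    """Relaxed contrastive loss on student embeddings given soft pseudo labels weights[i][j] in [0, 1].

    Assumed form: mean over anchors i of the sum over j != i of
    w_ij * d_ij^2 + (1 - w_ij) * max(delta - d_ij, 0)^2,
    where d_ij is the distance between i and j divided by i's mean distance to the other samples.
    Requires at least two embeddings.
    """
    n = len(student)
    dist = [[math.dist(student[i], student[j]) for j in range(n)] for i in range(n)]
    total = 0.0
    for i in range(n):
        # relative distances make the loss invariant to the overall scale of the embedding space
        mean_dist = max(sum(dist[i][j] for j in range(n) if j != i) / (n - 1), eps)
        for j in range(n):
            if j == i:
                continue
            d = dist[i][j] / mean_dist
            w = weights[i][j]
            total += w * d ** 2 + (1.0 - w) * max(delta - d, 0.0) ** 2
    return total / n


points = [[0.0, 0.0], [3.0, 4.0], [0.0, 1.0]]
soft_labels = [[1.0, 0.1, 0.8], [0.1, 1.0, 0.2], [0.8, 0.2, 1.0]]
print(relaxed_contrastive_loss(points, soft_labels, delta=1.0))
```

Because distances are normalized per anchor, scaling every embedding by the same factor leaves the loss unchanged, so the margin delta is expressed relative to the typical distance in the batch rather than in absolute units.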
- Train an embedding network with GoogLeNet (d=512) using MoCo on CUB-200-2011
python3 code/main_SSL.py --gpu-id 0 \
--model googlenet \
--embedding_size 512 \
--optimizer adamp \
--lr 1e-4 \
--dataset cub \
--view 2 \
--method moco \
--memory-size 9600
- Train an embedding network with GoogLeNet (d=512) using MeanShift on CUB-200-2011
- Note that BYOL is equivalent to MeanShift with --topk 1
python3 code/main_SSL.py --gpu-id 0 \
--model googlenet \
--embedding_size 512 \
--optimizer adamp \
--lr 1e-4 \
--dataset cub \
--view 2 \
--method meanshift \
--memory-size 9600 \
--topk 5
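The BYOL note can be checked with a small sketch of a MeanShift-style loss under our own simplifying assumptions: embeddings are L2-normalized, the target is included in the neighbor search, the top-k nearest neighbors of the target (by cosine similarity) are taken from a memory bank, and the loss is the mean squared distance from the query to those neighbors. With topk=1 the only neighbor is the target itself, so the loss reduces to the BYOL-style ||q - t||^2 = 2 - 2 cos(q, t). A deque with maxlen stands in for a FIFO memory of --memory-size entries; all names here are hypothetical.

```python
import math
from collections import deque


def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v)) or 1.0  # avoid dividing a zero vector by zero
    return [x / norm for x in v]


def update_memory(memory, target):
    """Push a normalized target embedding into a FIFO memory; the oldest entry drops out at maxlen."""
    memory.append(l2_normalize(target))


def meanshift_loss(query, target, memory, topk=5):
    """Mean squared distance from the query to the target's top-k neighbors (memory entries pre-normalized)."""
    q, t = l2_normalize(query), l2_normalize(target)
    # the target is listed first so it wins ties and is always its own nearest neighbor
    bank = [t] + list(memory)
    neighbors = sorted(bank, key=lambda m: -sum(a * b for a, b in zip(t, m)))[:topk]
    return sum(sum((a - b) ** 2 for a, b in zip(q, m)) for m in neighbors) / len(neighbors)


memory = deque(maxlen=9600)  # plays the role of --memory-size
for emb in ([0.0, 1.0], [0.6, 0.8], [-1.0, 0.0]):
    update_memory(memory, emb)

query, target = [0.0, 1.0], [1.0, 0.0]
byol_like = meanshift_loss(query, target, memory, topk=1)
print(byol_like, 2 - 2 * sum(a * b for a, b in zip(l2_normalize(query), l2_normalize(target))))
```

Raising topk (for example to 5, as in the command above) averages the pull toward several neighbors of the target instead of the target alone.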
Our source code is modified and adapted from these great repositories:
- Embedding Transfer with Label Relaxation for Improved Metric Learning
- Proxy Anchor Loss for Deep Metric Learning
- No Fuss Distance Metric Learning using Proxies
- PyTorch Metric Learning
- MoCo: Momentum Contrast for Unsupervised Visual Representation Learning
- Mean Shift for Self-Supervised Learning
If you use this method or this code in your research, please cite as:
@inproceedings{kim2022self,
title={Self-Taught Metric Learning without Labels},
author={Kim, Sungyeon and Kim, Dongwon and Cho, Minsu and Kwak, Suha},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
year={2022}
}