Self-Taught Metric Learning without Labels (CVPR 2022)


Official PyTorch implementation of the CVPR 2022 paper Self-Taught Metric Learning without Labels (STML). A standard embedding network trained with STML achieves state-of-the-art performance on unsupervised metric learning benchmarks and sometimes even surpasses supervised models. This repository provides the source code for unsupervised metric learning experiments on three datasets (CUB-200-2011, Cars-196, and Stanford Online Products).

Updates

  • [23.04.22] Added code for training and testing an embedding network with SSL methods (MoCo, BYOL, MeanShift).
  • [23.04.22] Fixed the "same device as the indexed tensor" bug in evaluation.

Overview

Self-Taught Metric Learning

  1. Contextualized semantic similarity between a pair of samples is estimated in the embedding space of the teacher network.
  2. The semantic similarity is then used as a pseudo label, and the student network is optimized with the relaxed contrastive loss with KL divergence.
  3. The teacher network is updated by an exponential moving average (EMA) of the student (a sketch of this update follows the list).
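
Step 3 can be written as theta_teacher <- m * theta_teacher + (1 - m) * theta_student. A minimal PyTorch sketch of this EMA update is given below; the momentum value and module names are illustrative and not taken from this repository's code.

import copy
import torch
import torch.nn as nn

@torch.no_grad()
def ema_update(teacher: nn.Module, student: nn.Module, momentum: float = 0.999) -> None:
    # Exponential moving average: teacher <- m * teacher + (1 - m) * student
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.mul_(momentum).add_(s_param, alpha=1.0 - momentum)

# Toy usage: the teacher starts as a frozen copy of the student and is
# updated by EMA after every optimizer step on the student.
student = nn.Linear(128, 512)   # stand-in for the student embedding network
teacher = copy.deepcopy(student)
for p in teacher.parameters():
    p.requires_grad_(False)
ema_update(teacher, student)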


Experimental Results

  • Our model with a 128-dimensional embedding outperforms all prior arts that use higher embedding dimensions and sometimes surpasses supervised learning methods.


Requirements

  • Python3
  • PyTorch (> 1.0)
  • NumPy
  • tqdm
  • wandb
  • AdamP
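
These dependencies can typically be installed with pip; the package names below are the usual PyPI names and are an assumption, not a pinned requirements file from this repository:

pip install torch numpy tqdm wandb adamp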

Datasets

  1. Download the three public benchmark datasets for deep metric learning (CUB-200-2011, Cars-196, and Stanford Online Products)

  2. Extract each tgz or zip file into ./data/ (for Cars-196, put the files in ./data/cars196 instead); an example layout is shown below
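
After extraction, ./data/ should contain one folder per benchmark. The layout below is only an example based on the standard archive names; check the dataset code in this repository for the exact directory names it expects.

./data
├── CUB_200_2011
├── cars196
└── Stanford_Online_Products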

Training Embedding Network

CUB-200-2011 (Unsupervised)

  • Train an embedding network with GoogLeNet (d=512) using STML
python3 code/main.py --gpu-id 0 \
                        --model googlenet \
                        --embedding_size 512 \
                        --optimizer adamp \
                        --lr 1e-4 \
                        --dataset cub \
                        --view 2 \
                        --sigma 3 \
                        --delta 1 \
                        --num_neighbors 5
  • Train an embedding network with BN-Inception (d=512) using STML
python3 code/main.py --gpu-id 0 \
                        --model bn_inception \
                        --embedding_size 512 \
                        --optimizer adamp \
                        --lr 1e-4 \
                        --dataset cub \
                        --view 2 \
                        --sigma 3 \
                        --delta 1 \
                        --num_neighbors 5 \
                        --bn-freeze 1

Cars-196 (Unsupervised)

  • Train an embedding network with GoogLeNet (d=512) using STML
python3 code/main.py --gpu-id 0 \
                        --model googlenet \
                        --embedding_size 512 \
                        --optimizer adamp \
                        --lr 1e-4 \
                        --dataset cars \
                        --view 2 \
                        --sigma 3 \
                        --delta 1 \
                        --num_neighbors 5
  • Train an embedding network with BN-Inception (d=512) using STML
python3 code/main.py --gpu-id 0 \
                        --model bn_inception \
                        --embedding_size 512 \
                        --optimizer adamp \
                        --lr 1e-4 \
                        --dataset cars \
                        --view 2 \
                        --sigma 3 \
                        --delta 1 \
                        --num_neighbors 5 \
                        --bn-freeze 1

Stanford Online Products (Unsupervised)

  • Train an embedding network with GoogLeNet (d=512) using STML
python3 code/main.py --gpu-id 0 \
                        --model googlenet \
                        --embedding_size 512 \
                        --optimizer adamp \
                        --lr 1e-4 \
                        --dataset SOP \
                        --view 2 \
                        --sigma 3 \
                        --delta 0.9 \
                        --num_neighbors 2 \
                        --momentum 0.9 \
                        --weight-decay 1e-2 \
                        --emb-lr 1e-2
  • Train an embedding network with BN-Inception (d=512) using STML
python3 code/main.py --gpu-id 0 \
                        --model bn_inception \
                        --embedding_size 512 \
                        --optimizer adamp \
                        --lr 1e-4 \
                        --dataset SOP \
                        --view 2 \
                        --sigma 3 \
                        --delta 0.9 \
                        --num_neighbors 2 \
                        --momentum 0.9 \
                        --weight-decay 1e-2 \
                        --emb-lr 1e-2 \
                        --bn-freeze 1

Stanford Online Products (Unsupervised & From Scratch)

  • Train an embedding network with ResNet18 (d=128) using STML
python3 code/main.py --gpu-id 0 \
                        --model resnet18 \
                        --embedding_size 128 \
                        --optimizer adamp \
                        --lr 5e-4 \
                        --dataset SOP \
                        --view 2 \
                        --sigma 3 \
                        --delta 0.9 \
                        --num_neighbors 2 \
                        --momentum 0.9 \
                        --pretrained false \
                        --weight-decay 1e-2 \
                        --batch-size 120 \
                        --epoch 180 \
                        --fix_lr true

Training Embedding Network using SSL Method

Example using MoCo

  • Train an embedding network with GoogLeNet (d=512) using MoCo
python3 code/main_SSL.py --gpu-id 0 \
                        --model googlenet \
                        --embedding_size 512 \
                        --optimizer adamp \
                        --lr 1e-4 \
                        --dataset cub \
                        --view 2 \
                        --method moco \
                        --memory-size 9600
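
MoCo contrasts each embedding against a fixed-size queue of negatives collected from past batches (here --memory-size 9600). A minimal sketch of such a FIFO embedding queue is shown below; the class and method names are illustrative, not this repository's API.

import torch
import torch.nn.functional as F

class EmbeddingQueue:
    # Fixed-size FIFO memory of L2-normalized embeddings, as used by MoCo-style methods.
    def __init__(self, size: int = 9600, dim: int = 512):
        self.bank = F.normalize(torch.randn(size, dim), dim=1)
        self.ptr = 0

    @torch.no_grad()
    def enqueue(self, emb: torch.Tensor) -> None:
        # Overwrite the oldest entries with the newest batch of embeddings.
        n = emb.size(0)
        idx = torch.arange(self.ptr, self.ptr + n) % self.bank.size(0)
        self.bank[idx] = F.normalize(emb, dim=1)
        self.ptr = int((self.ptr + n) % self.bank.size(0))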

Example using MeanShift

  • Train an embedding network with GoogLeNet (d=512) using MeanShift
  • Note that BYOL is equivalent to MeanShift with --topk 1 (see the sketch after the command below)
python3 code/main_SSL.py --gpu-id 0 \
                        --model googlenet \
                        --embedding_size 512 \
                        --optimizer adamp \
                        --lr 1e-4 \
                        --dataset cub \
                        --view 2 \
                        --method meanshift \
                        --memory-size 9600 \
                        --topk 5 
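
The relation between the two comes from how the regression target is built: MeanShift pulls the student embedding toward the mean of the k nearest neighbors of the teacher embedding in a memory bank, so when the teacher embedding itself is in the bank, topk 1 reduces to regressing onto the teacher embedding, which is the BYOL objective. A minimal sketch of such a target computation is given below; names and shapes are illustrative, not this repository's API.

import torch
import torch.nn.functional as F

def meanshift_target(teacher_emb: torch.Tensor, memory_bank: torch.Tensor, topk: int = 5) -> torch.Tensor:
    # teacher_emb: (B, D) L2-normalized teacher embeddings of the current batch.
    # memory_bank: (M, D) L2-normalized embeddings from recent batches (including the current ones).
    sim = teacher_emb @ memory_bank.t()        # cosine similarities, shape (B, M)
    _, idx = sim.topk(topk, dim=1)             # indices of the k nearest neighbors
    neighbors = memory_bank[idx]               # gather neighbors, shape (B, k, D)
    return F.normalize(neighbors.mean(dim=1), dim=1)  # with topk=1 this is just the 1-NN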

Acknowledgements

Our source code is modified and adapted from these great repositories:

Citation

If you use this method or this code in your research, please cite as:

@inproceedings{kim2022self,
  title={Self-Taught Metric Learning without Labels},
  author={Kim, Sungyeon and Kim, Dongwon and Cho, Minsu and Kwak, Suha},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  year={2022}
}
