
SimSiam and Refactoring of Models and Dataset

@philippmwirth released this 11 Jan 11:30
b8a1988

This release contains breaking changes: the SimCLR and MoCo models, the LightlyDataset, and the BaseCollateFunction have been refactored to make the code base easier to understand.

SimSiam (@busycalibrating)

This release introduces an implementation of the SimSiam self-supervised learning framework. SimSiam uses a Siamese network architecture and maximizes the similarity between two augmented views of the same image.
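To illustrate the idea, here is a minimal sketch in plain PyTorch, not lightly's SimSiam model. It follows the setup described in the SimSiam paper: both views pass through one shared encoder, a small predictor head is applied on top, and the loss is a symmetrized negative cosine similarity in which the target branch is detached (stop-gradient). The class `TinySimSiam`, the helper function names, and all layer sizes are hypothetical choices for this example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinySimSiam(nn.Module):
    """Hypothetical toy model: a shared encoder plus a predictor head."""

    def __init__(self, in_dim: int = 32, dim: int = 16, pred_dim: int = 8):
        super().__init__()
        # Encoder shared by both views (the "Siamese" part).
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, dim), nn.BatchNorm1d(dim), nn.ReLU(), nn.Linear(dim, dim)
        )
        # Predictor applied on top of the encoder output.
        self.predictor = nn.Sequential(
            nn.Linear(dim, pred_dim), nn.ReLU(), nn.Linear(pred_dim, dim)
        )

    def forward(self, x0: torch.Tensor, x1: torch.Tensor):
        z0, z1 = self.encoder(x0), self.encoder(x1)
        p0, p1 = self.predictor(z0), self.predictor(z1)
        return (z0, p0), (z1, p1)


def negative_cosine_similarity(p: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    # Stop-gradient: z is detached, so it acts as a constant target.
    return -F.cosine_similarity(p, z.detach(), dim=-1).mean()


def simsiam_loss(z0, p0, z1, p1) -> torch.Tensor:
    # Symmetrized: each view's prediction targets the other view's encoding.
    return 0.5 * negative_cosine_similarity(p0, z1) + 0.5 * negative_cosine_similarity(p1, z0)


# Usage: two noisy "augmentations" of the same random inputs.
model = TinySimSiam()
x = torch.randn(8, 32)
x0 = x + 0.1 * torch.randn_like(x)
x1 = x + 0.1 * torch.randn_like(x)
(z0, p0), (z1, p1) = model(x0, x1)
loss = simsiam_loss(z0, p0, z1, p1)
loss.backward()
```

Because the loss is a mean of negative cosine similarities, it always lies in [-1, 1]; -1 means the predictions align perfectly with the detached targets.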

Refactoring: LightlyDataset

The LightlyDataset has been refactored so that its constructor always expects an input directory, input_dir, which indicates where the images are stored. To use a LightlyDataset with an arbitrary PyTorch dataset, use the class method LightlyDataset.from_torch_dataset instead.

1.0.7 (incompatible)

>>> dataset = LightlyDataset(from_folder='path/to/data')
>>>
>>> dataset = LightlyDataset(root='./', name='cifar10', download=True)

1.0.8

>>> import torchvision
>>> from lightly.data import LightlyDataset
>>>
>>> dataset = LightlyDataset(input_dir='path/to/data')
>>>
>>> torch_dataset = torchvision.datasets.CIFAR10(root='./', download=True)
>>> dataset = LightlyDataset.from_torch_dataset(torch_dataset)
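Conceptually, wrapping an arbitrary dataset only requires attaching a filename to every (image, label) sample, since the collate function returns labels and filenames alongside the images. The sketch below is a hypothetical, dependency-free illustration of that idea and not the implementation of LightlyDataset.from_torch_dataset: the class name `IndexedFilenameDataset` and the zero-padded filename scheme are assumptions. It works with any map-style dataset, i.e. any object with `__len__` and `__getitem__`, which is the protocol `torch.utils.data.DataLoader` expects.

```python
from typing import Any, Callable, Optional, Tuple


class IndexedFilenameDataset:
    """Hypothetical wrapper: yields (image, label, filename) triples."""

    def __init__(self, dataset: Any, transform: Optional[Callable[[Any], Any]] = None):
        # `dataset` is any map-style dataset returning (image, label) pairs.
        self.dataset = dataset
        self.transform = transform

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> Tuple[Any, Any, str]:
        image, label = self.dataset[index]
        if self.transform is not None:
            image = self.transform(image)
        # Assumed naming scheme: the zero-padded sample index.
        filename = f"{index:06d}.png"
        return image, label, filename


# Usage with a plain list standing in for a torchvision dataset.
wrapped = IndexedFilenameDataset([("img_a", 0), ("img_b", 1)])
print(wrapped[1])  # ('img_b', 1, '000001.png')
```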

Refactoring: BaseCollateFunction

The BaseCollateFunction now returns the two batches of augmented images separately, together with the labels and filenames, in the form (aug0, aug1), labels, filenames, where aug0 and aug1 both have shape bsz x channels x H x W.
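The following is a minimal sketch in plain PyTorch of a collate function with this output format; it is not lightly's BaseCollateFunction. It applies a random transform twice per image to obtain two independent views and stacks each set into a bsz x channels x H x W batch. The function names and the toy flip-plus-noise augmentation are hypothetical stand-ins for a real augmentation pipeline.

```python
import functools
from typing import Callable, List, Sequence, Tuple

import torch
from torch.utils.data import DataLoader


def two_view_collate(
    batch: Sequence[Tuple[torch.Tensor, int, str]],
    transform: Callable[[torch.Tensor], torch.Tensor],
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, List[str]]:
    # `batch` is a list of (image, label, filename) samples.
    images, labels, filenames = zip(*batch)
    # Transform every image twice to get two independently augmented views.
    aug0 = torch.stack([transform(img) for img in images])
    aug1 = torch.stack([transform(img) for img in images])
    return (aug0, aug1), torch.tensor(labels), list(filenames)


def random_flip_and_noise(img: torch.Tensor) -> torch.Tensor:
    # Toy augmentation: random horizontal flip plus small Gaussian noise.
    if torch.rand(1).item() < 0.5:
        img = torch.flip(img, dims=[-1])
    return img + 0.05 * torch.randn_like(img)


# Usage: a list of random 3 x 32 x 32 "images" stands in for a real dataset.
samples = [(torch.rand(3, 32, 32), i % 10, f"img_{i}.png") for i in range(128)]
loader = DataLoader(
    samples,
    batch_size=128,
    collate_fn=functools.partial(two_view_collate, transform=random_flip_and_noise),
)
(aug0, aug1), labels, filenames = next(iter(loader))
print(aug0.shape)  # torch.Size([128, 3, 32, 32])
```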

Refactoring: SimCLR, MoCo and NTXentLoss

In line with the changes to the BaseCollateFunction, SimCLR and MoCo now expect the two batches of augmented images separately instead of as a single concatenated batch. Similarly, the NTXentLoss now takes the two batches of representations as separate inputs.

1.0.7 (incompatible)

>>> # batch size is 128; both augmented views are concatenated into one batch
>>> batch, labels, filenames = next(iter(dataloader))
>>> batch.shape
torch.Size([256, 3, 32, 32])
>>> # number of features is 64
>>> y = simclr(batch)
>>> y.shape
torch.Size([256, 64])
>>> loss = ntx_ent_loss(y)

1.0.8

>>> # batch size is 128
>>> (batch0, batch1), labels, filenames = next(iter(dataloader))
>>> batch0.shape
torch.Size([128, 3, 32, 32])
>>> batch1.shape
torch.Size([128, 3, 32, 32])
>>> # number of features is 64
>>> y0, y1 = simclr(batch0, batch1)
>>> y0.shape
torch.Size([128, 64])
>>> y1.shape
torch.Size([128, 64])
>>> loss = ntx_ent_loss(y0, y1)
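To see why the loss naturally takes two inputs, here is a minimal sketch of the standard NT-Xent (normalized temperature-scaled cross-entropy) formulation in plain PyTorch. It is not lightly's NTXentLoss implementation; the function name `nt_xent` and the default temperature of 0.5 are assumptions. Row i of y0 and row i of y1 form the positive pair, and every other row of the combined 2N batch serves as a negative.

```python
import torch
import torch.nn.functional as F


def nt_xent(y0: torch.Tensor, y1: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    n = y0.shape[0]
    # L2-normalize so that dot products are cosine similarities.
    z = F.normalize(torch.cat([y0, y1], dim=0), dim=1)  # shape (2N, D)
    sim = z @ z.t() / temperature  # shape (2N, 2N)
    # Exclude each sample's similarity with itself from the softmax.
    self_mask = torch.eye(2 * n, dtype=torch.bool)
    sim = sim.masked_fill(self_mask, float("-inf"))
    # The positive for row i is row i + N, and vice versa.
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)])
    return F.cross_entropy(sim, targets)


# Usage with the shapes from the example above: batch size 128, 64 features.
torch.manual_seed(0)
y0 = torch.randn(128, 64)
y1 = y0 + 0.1 * torch.randn(128, 64)  # stand-in for the second view's features
loss = nt_xent(y0, y1)
```

Keeping the two batches separate makes the positive pairing explicit (same row index in y0 and y1) instead of relying on a convention about how a single concatenated batch was ordered.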

Documentation Updates

A tutorial on how to use the SimSiam model has been added, along with some minor changes and improvements.

Minor Changes

Private functions are hidden from autocompletion.

Models