Skip to content

Commit

Permalink
Merge pull request #53 from lightly-ai/develop
Browse files Browse the repository at this point in the history
Pre-release 1.0.5 - Develop to Master
  • Loading branch information
IgorSusmelj authored Nov 27, 2020
2 parents 3d7cdd0 + 500c86b commit bbc4b7d
Show file tree
Hide file tree
Showing 18 changed files with 405 additions and 161 deletions.
59 changes: 56 additions & 3 deletions lightly/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
""" Lightly is a computer vision framework for self-supervised learning.
"""Lightly is a computer vision framework for self-supervised learning.
With Lightly you can train deep learning models using
self-supervision. This means, that you don't require
Expand All @@ -7,15 +7,68 @@
It is built on top of PyTorch and therefore fully compatible
with other frameworks such as Fast.ai.
For information about the command-line interace, see lightly.cli.
The framework is structured into the following modules:
- **api**:
The lightly.api module handles communication with the Lightly web-app.
- **cli**:
The lightly.cli module provides a command-line interface for training
self-supervised models and embedding images. Furthermore, the command-line
tool can be used to upload and download images from/to the Lightly web-app.
- **core**:
The lightly.core module offers one-liners for simple self-supervised learning.
- **data**:
The lightly.data module provides a dataset wrapper and collate functions. The
collate functions are in charge of the data augmentations which are crucial for
self-supervised learning.
- **embedding**:
The lightly.embedding module combines the self-supervised models with a dataloader,
optimizer, and loss function to provide a simple pytorch-lightning trainable.
- **loss**:
The lightly.loss module contains implementations of popular self-supervised training
loss functions.
- **models**:
The lightly.models module holds the implementation of the ResNet as well as self-
supervised methods. Currently implements:
- SimCLR
- MoCo
- **transforms**:
The lightly.transforms module implements custom data transforms. Currently implements:
- Gaussian Blur
- Random Rotation
- **utils**:
The lightly.utils package provides global utility methods.
The io module contains utility to save and load embeddings in a format which is
understood by the Lightly library.
"""

# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

__name__ = 'lightly'
__version__ = '1.0.4'
__version__ = '1.0.5'


try:
Expand Down
14 changes: 13 additions & 1 deletion lightly/api/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import random

import numpy as np
import torchvision

from itertools import islice

Expand Down Expand Up @@ -338,6 +339,7 @@ def upload_images_from_folder(path_to_folder: str,
token: str,
max_workers: int = 8,
mode: str = 'thumbnails',
size: int = -1,
verbose: bool = True):
"""Uploads images from a directory to the Lightly cloud solution.
Expand All @@ -356,6 +358,12 @@ def upload_images_from_folder(path_to_folder: str,
mode:
One of [full, thumbnails, metadata]. Whether to upload thumbnails,
full images, or metadata only.
size:
Desired output size. If negative, default output size is used.
If size is a sequence like (h, w), output size will be matched to
this. If size is an int, smaller edge of the image will be matched
to this number. i.e, if height > width, then image will be rescaled
to (size * height / width, size).
Raises:
ValueError if dataset is too large.
Expand All @@ -364,7 +372,11 @@ def upload_images_from_folder(path_to_folder: str,
"""

dataset = LightlyDataset(from_folder=path_to_folder)
transform = None
if isinstance(size, tuple) or size > 0:
transform = torchvision.transforms.Resize(size)

dataset = LightlyDataset(from_folder=path_to_folder, transform=transform)
upload_dataset(
dataset,
dataset_id,
Expand Down
8 changes: 8 additions & 0 deletions lightly/cli/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ token: '' # User access token to the platform.
dataset_id: '' # Identifier of the dataset on the platform
upload: 'full' # Whether to upload full images, thumbnails only, or metadata only.
# Must be one of ['full', 'thumbnails', 'none']
resize: -1 # Allow resizing of the images before uploading, usage =-1, =x, =[x,y]
embedding_name: 'default' # Name of the embedding to be used on the platform.
emb_upload_bsz: 32 # Number of embeddings which are uploaded in a single batch.
tag_name: 'initial-tag' # Name of the requested tag on the Lightly platform.
Expand Down Expand Up @@ -74,6 +75,13 @@ trainer:
max_epochs: 100 # Number of epochs to train for.
precision: 32 # If set to 16, will use half-precision.

# checkpoint_callback namespace: Modify the checkpoint callback
checkpoint_callback:
save_last: True # Whether to save the checkpoint from the last epoch.
save_top_k: 1 # Save the top k checkpoints.
dirpath: # Where to store the checkpoints (empty field resolves to None).
# If not set, checkpoints are stored in the hydra output dir.

# seed
seed: 1

Expand Down
6 changes: 2 additions & 4 deletions lightly/cli/embed_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,9 @@ def _embed_cli(cfg, is_cli_call=True):
checkpoint, map_location=device
)['state_dict']

model = ResNetSimCLR(**cfg['model']).to(device)
if state_dict is not None:
model = ResNetSimCLR.from_state_dict(state_dict, **cfg['model'])
model = model.to(device)
else:
model = ResNetSimCLR(**cfg['model']).to(device)
model.load_from_state_dict(state_dict)

encoder = SelfSupervisedEmbedding(model, None, None, None)
embeddings, labels, filenames = encoder.embed(dataloader, device=device)
Expand Down
5 changes: 4 additions & 1 deletion lightly/cli/lightly_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ def _lightly_cli(cfg, is_cli_call=True):

cfg['loader']['shuffle'] = True
cfg['loader']['drop_last'] = True
checkpoint = _train_cli(cfg, is_cli_call)
if cfg['trainer']['max_epochs'] > 0:
checkpoint = _train_cli(cfg, is_cli_call)
else:
checkpoint = ''

cfg['loader']['shuffle'] = False
cfg['loader']['drop_last'] = False
Expand Down
6 changes: 3 additions & 3 deletions lightly/cli/train_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,9 @@ def _train_cli(cfg, is_cli_call=True):
)['state_dict']

# load model
model = ResNetSimCLR(**cfg['model'])
if state_dict is not None:
model = ResNetSimCLR.from_state_dict(state_dict, **cfg['model'])
else:
model = ResNetSimCLR(**cfg['model'])
model.load_from_state_dict(state_dict)

criterion = NTXentLoss(**cfg['criterion'])
optimizer = torch.optim.SGD(model.parameters(), **cfg['optimizer'])
Expand All @@ -107,6 +106,7 @@ def _train_cli(cfg, is_cli_call=True):
collate_fn=collate_fn)

encoder = SelfSupervisedEmbedding(model, criterion, optimizer, dataloader)
encoder.init_checkpoint_callback(**cfg['checkpoint_callback'])
encoder = encoder.train_embedding(**cfg['trainer'])

print('Best model is stored at: %s' % (encoder.checkpoint))
Expand Down
7 changes: 6 additions & 1 deletion lightly/cli/upload_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def _upload_cli(cfg, is_cli_call=True):
dataset_id = cfg['dataset_id']
token = cfg['token']

size = cfg['resize']
if not isinstance(size, int):
size = tuple(size)

if not token or not dataset_id:
print('Please specify your access token and dataset id.')
print('For help, try: lightly-upload --help')
Expand All @@ -36,7 +40,8 @@ def _upload_cli(cfg, is_cli_call=True):
if input_dir:
mode = cfg['upload']
try:
upload_images_from_folder(input_dir, dataset_id, token, mode=mode)
upload_images_from_folder(
input_dir, dataset_id, token, mode=mode, size=size)
except (ValueError, ConnectionRefusedError) as error:
msg = f'Error: {error}'
print(msg)
Expand Down
16 changes: 12 additions & 4 deletions lightly/data/_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# All Rights Reserved

import os
import av
from PIL import Image
from torchvision import datasets

Expand Down Expand Up @@ -235,12 +234,21 @@ def get_filename(self, index):
i = len(self.offsets) - 1
while (self.offsets[i] > index):
i = i - 1


# get filename of the video file
filename = self.videos[i]
filename = os.path.relpath(filename, self.root)

# get video format and video name
splits = filename.split('.')
video_format = splits[-1]
video_name = '.'.join(splits[:-1])
timestamp = float(self.video_timestamps[i][index - self.offsets[i]])
return '%s-%.8fs-%s.png' % (video_name, timestamp, video_format)

# get frame number
frame_number = index - self.offsets[i]
if i < len(self.offsets) - 1:
n_frames = self.offsets[i+1] - self.offsets[i]
else:
n_frames = self.__len__() - self.offsets[i]

return f'{video_name}-{frame_number:0{len(str(n_frames))}}-{video_format}.png'
27 changes: 12 additions & 15 deletions lightly/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,7 @@

from lightly.data._helpers import _load_dataset
from lightly.data._helpers import DatasetFolder

try:
from lightly.data._video import VideoDataset
VIDEO_DATASET_AVAILABLE = True
except Exception:
VIDEO_DATASET_AVAILABLE = False

from lightly.data._video import VideoDataset

class LightlyDataset(data.Dataset):
"""Provides a uniform data interface for the embedding models.
Expand Down Expand Up @@ -74,7 +68,7 @@ def __init__(self,

def dump_image(self,
output_dir: str,
filename: str,
index: int,
format: Union[str, None] = None):
"""Saves a single image to the output directory.
Expand All @@ -85,14 +79,14 @@ def dump_image(self,
Args:
output_dir:
Output directory where the image is stored.
filename:
Filename of the image to store.
index:
Index of the image to store.
format:
Image format.
"""
index = self.get_filenames().index(filename)
image, _ = self.dataset[index]
filename = self._get_filename_by_index(index)

source = os.path.join(self.root_folder, filename)
target = os.path.join(output_dir, filename)
Expand Down Expand Up @@ -141,11 +135,14 @@ def dump(self,

# get all filenames
if filenames is None:
filenames = self.get_filenames()
indices = [i for i in range(self.__len__())]
else:
indices = \
[i for i, f in enumerate(self.get_filenames()) if f in filenames]

# dump images
for filename in filenames:
self.dump_image(output_dir, filename, format=format)
for index in indices:
self.dump_image(output_dir, index, format=format)

def get_filenames(self) -> List[str]:
"""Returns all filenames in the dataset.
Expand All @@ -166,7 +163,7 @@ def _get_filename_by_index(self, index) -> str:
elif isinstance(self.dataset, DatasetFolder):
full_path = self.dataset.samples[index][0]
return os.path.relpath(full_path, self.root_folder)
elif VIDEO_DATASET_AVAILABLE and isinstance(self.dataset, VideoDataset):
elif isinstance(self.dataset, VideoDataset):
return self.dataset.get_filename(index)
else:
return str(index)
Expand Down
53 changes: 46 additions & 7 deletions lightly/embedding/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,11 @@ def __init__(self,
self.dataloader = dataloader
self.scheduler = scheduler
self.checkpoint = None
# create custom model checkpoint and set attributes
self.checkpoint_callback = CustomModelCheckpoint()
self.checkpoint_callback.save_last = True
self.checkpoint_callback.save_top_k = 1
self.checkpoint_callback.monitor = 'loss'
self.cwd = os.getcwd()

self.checkpoint_callback = None
self.init_checkpoint_callback()

def forward(self, x):
return self.model(x)

Expand Down Expand Up @@ -81,8 +79,22 @@ def train_embedding(self, **kwargs):
A trained encoder, ready for embedding datasets.
"""
trainer = pl.Trainer(**kwargs,
checkpoint_callback=self.checkpoint_callback)
# backwards compatability for old pytorch-lightning versions:
# they changed the way checkpoint callbacks are passed in v1.0.3
# -> do a simple version check
# TODO: remove when incrementing minimum requirement for pl
pl_version = [int(v) for v in pl.__version__.split('.')]
ok_version = [1, 0, 4]
deprecated_checkpoint_callback = \
all([pl_v >= ok_v for pl_v, ok_v in zip(pl_version, ok_version)])

if deprecated_checkpoint_callback:
trainer = pl.Trainer(**kwargs,
callbacks=[self.checkpoint_callback])
else:
trainer = pl.Trainer(**kwargs,
checkpoint_callback=self.checkpoint_callback)

trainer.fit(self)

self.checkpoint = self.checkpoint_callback.best_model_path
Expand Down Expand Up @@ -127,3 +139,30 @@ def embed(self, *args, **kwargs):
"""
raise NotImplementedError()

def init_checkpoint_callback(self,
save_last=True,
save_top_k=1,
monitor='loss',
dirpath=None):
"""Initializes the checkpoint callback.
Args:
save_last:
Whether or not to save the checkpoint of the last epoch.
save_top_k:
Save the top_k model checkpoints.
monitor:
Which quantity to monitor.
dirpath:
Where to save the checkpoint.
"""
# initialize custom model checkpoint
self.checkpoint_callback = CustomModelCheckpoint()
self.checkpoint_callback.save_last = save_last
self.checkpoint_callback.save_top_k = save_top_k
self.checkpoint_callback.monitor = monitor

dirpath = self.cwd if dirpath is None else dirpath
self.checkpoint_callback.dirpath = dirpath
12 changes: 0 additions & 12 deletions lightly/models/_helpers.py

This file was deleted.

Loading

0 comments on commit bbc4b7d

Please sign in to comment.