Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SimCLR v1 #42

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions configs/simclrv1_pretext_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import os
import ml_collections

def get_wandb_configs() -> ml_collections.ConfigDict:
    """Return the Weights & Biases logging configuration (project/entity)."""
    return ml_collections.ConfigDict({
        "project": "ssl-study",
        "entity": "wandb_fc",
    })

def get_dataset_configs() -> ml_collections.ConfigDict:
    """Return input-pipeline settings: image size, channels, batching, classes."""
    return ml_collections.ConfigDict({
        "image_height": 224,  # default - 224
        "image_width": 224,   # default - 224
        "channels": 3,
        "batch_size": 64,
        "num_classes": 200,
    })

def get_augment_configs() -> ml_collections.ConfigDict:
    """Return hyperparameters for the SimCLR view-augmentation pipeline."""
    return ml_collections.ConfigDict({
        "image_height": 224,  # default - 224
        "image_width": 224,   # default - 224
        "cropscale": (0.08, 1.0),
        "cropratio": (0.75, 1.3333333333333333),
        "jitterbrightness": 0.2,
        "jittercontrast": 0.2,
        "jittersaturation": 0.2,
        "jitterhue": 0.2,
        "gaussianblurlimit": (3, 7),
        "gaussiansigmalimit": 0,
        "alwaysapply": False,
        "probability": 0.5,
    })

def get_bool_configs() -> ml_collections.ConfigDict:
    """Return boolean feature switches for the pipeline."""
    return ml_collections.ConfigDict({
        "apply_resize": True,
        "do_cache": False,
        "use_cosine_similarity": True,
    })

def get_model_configs() -> ml_collections.ConfigDict:
    """Return the backbone choice and projection-head layer widths."""
    return ml_collections.ConfigDict({
        "backbone": "resnet50",
        "hidden1": 256,
        "hidden2": 128,
        "hidden3": 50,
    })

def get_train_configs() -> ml_collections.ConfigDict:
    """Return training-loop settings (epochs, NT-Xent temperature, scale s)."""
    return ml_collections.ConfigDict({
        "epochs": 30,
        "temperature": 0.5,
        "s": 1,
    })

def get_learning_rate_configs() -> ml_collections.ConfigDict:
    """Return the learning-rate schedule configuration.

    Bug fix: the original omitted the ``return`` statement, so callers
    (e.g. ``get_config``) received ``None`` instead of the ConfigDict.
    """
    configs = ml_collections.ConfigDict()
    configs.decay_steps = 1000
    configs.initial_learning_rate = 0.1

    return configs

def get_config() -> ml_collections.ConfigDict:
    """Assemble the top-level experiment config from the section builders."""
    return ml_collections.ConfigDict({
        "seed": 0,
        "wandb_config": get_wandb_configs(),
        "dataset_config": get_dataset_configs(),
        "augmentation_config": get_augment_configs(),
        "bool_config": get_bool_configs(),
        "train_config": get_train_configs(),
        "learning_rate_config": get_learning_rate_configs(),
        "model_config": get_model_configs(),
    })
57 changes: 57 additions & 0 deletions simclrv1_pretext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# General imports
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import glob
import wandb
from absl import app
from absl import flags
import numpy as np
import tensorflow as tf
from ml_collections.config_flags import config_flags

# Import modules
from ssl_study.data import download_dataset, preprocess_dataframe
from ssl_study.simclrv1.pretext_task.data import GetDataloader
from ssl_study.simclrv1.pretext_task.models import SimCLRv1Model
from ssl_study.simclrv1.pretext_task.pipeline import SimCLRv1Pipeline

FLAGS = flags.FLAGS
CONFIG = config_flags.DEFINE_config_file("config")

def main(_):
    """Run SimCLR v1 pretext (contrastive) pre-training inside a W&B run.

    Downloads the unlabelled 'in-class' dataset, builds the dataloader,
    constructs the ResNet-based SimCLR model and trains it via the pipeline.
    """
    with wandb.init(
        entity=CONFIG.value.wandb_config.entity,
        project=CONFIG.value.wandb_config.project,
        job_type='simclrv1_pretext',
        config=CONFIG.value.to_dict(),
    ):
        # Access all hyperparameter values through wandb.config
        config = wandb.config
        # Seed Everything
        tf.random.set_seed(config.seed)

        # Load the dataframes
        inclass_df = download_dataset('in-class', 'unlabelled-dataset')

        # Preprocess the DataFrames
        # is_labelled=False -> returns image paths only (no labels).
        inclass_paths = preprocess_dataframe(inclass_df, is_labelled=False)

        # Build dataloaders
        dataset = GetDataloader(config)
        inclassloader = dataset.dataloader(inclass_paths)

        # Model
        tf.keras.backend.clear_session()
        model = SimCLRv1Model(config).get_model(config.model_config["hidden1"], config.model_config["hidden2"], config.model_config["hidden3"])
        model.summary()

        # Build the pipeline
        pipeline = SimCLRv1Pipeline(model, config)
        # NOTE(review): accessed without parentheses — presumably these are
        # @property members on SimCLRv1Pipeline; confirm, otherwise the bare
        # bound methods are handed to train_simclr uncalled.
        optimizer = pipeline.get_optimizer
        criterion = pipeline.get_criterion

        epoch_wise_loss, resnet_simclr = pipeline.train_simclr(model, inclassloader, optimizer, criterion, temperature=config.train_config["temperature"], epochs=config.train_config["epochs"])


if __name__ == "__main__":
    app.run(main)
4 changes: 2 additions & 2 deletions ssl_study/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .dataset import download_dataset, preprocess_dataset
from .dataset import download_dataset, preprocess_dataframe
from .dataloader import GetDataloader

__all__ = [
'download_dataset', 'preprocess_dataset', 'GetDataloader'
'download_dataset', 'preprocess_dataframe', 'GetDataloader'
]
30 changes: 19 additions & 11 deletions ssl_study/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def download_dataset(dataset_name: str,
data_df = pd.read_csv(save_at+'train.csv')
elif dataset_name == 'val' and os.path.exists(save_at+'valid.csv'):
data_df = pd.read_csv(save_at+'valid.csv')
elif dataset_name == 'in-class' and os.path.exists(save_at+'in-class.csv'):
data_df = pd.read_csv(save_at+'in-class.csv')
elif dataset_name == 'out-class' and os.path.exists(save_at+'out-class.csv'):
data_df = pd.read_csv(save_at+'out-class.csv')
else:
data_df = None
print('Downloading dataset...')
Expand Down Expand Up @@ -82,19 +86,23 @@ def download_dataset(dataset_name: str,
if dataset_name == 'val' and not os.path.exists(save_at+'valid.csv'):
data_df.to_csv(save_at+'valid.csv', index=False)

return data_df

if dataset_name == 'in-class' and not os.path.exists(save_at+'in-class.csv'):
data_df.to_csv(save_at+'in-class.csv', index=False)

def preprocess_dataset(df):
# TODO: take care of df without labels.
# Remove unnecessary columns
df = df.drop(['image_id', 'width', 'height'], axis=1)
assert len(df.columns) == 2
if dataset_name == 'out-class' and not os.path.exists(save_at+'out-class.csv'):
data_df.to_csv(save_at+'out-class.csv', index=False)

# Fix types
df[['label']] = df[['label']].apply(pd.to_numeric)
return data_df

def preprocess_dataframe(df, is_labelled=True):
    """Extract image paths (and optionally labels) from a dataset dataframe.

    Args:
        df: DataFrame with columns ``image_id``, ``width``, ``height``,
            ``image_path`` and, when labelled, ``label``.
        is_labelled: when True, also coerce and return the label column.

    Returns:
        ``(image_paths, labels)`` arrays when ``is_labelled`` is True,
        otherwise just ``image_paths``.

    Fix: the scraped diff interleaved removed lines (an unconditional
    ``labels = df.label.values`` and an early ``return``) into the new
    function; on unlabelled frames that crashed on the missing ``label``
    column / returned before the ``is_labelled`` branch ran.
    """
    # Drop metadata columns not needed downstream.
    df = df.drop(['image_id', 'width', 'height'], axis=1)
    image_paths = df.image_path.values

    if is_labelled:
        assert len(df.columns) == 2
        # Fix types: ensure labels are numeric.
        df[['label']] = df[['label']].apply(pd.to_numeric)
        labels = df.label.values
        return image_paths, labels

    return image_paths
Empty file.
6 changes: 6 additions & 0 deletions ssl_study/simclrv1/pretext_task/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .data_aug import Augment
from .dataloader import GetDataloader

__all__ = [
'Augment', 'GetDataloader'
]
55 changes: 55 additions & 0 deletions ssl_study/simclrv1/pretext_task/data/data_aug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from albumentations import augmentations
from albumentations.augmentations.transforms import ToGray
import numpy as np
import tensorflow as tf
from functools import partial
import albumentations as A

AUTOTUNE = tf.data.AUTOTUNE

class Augment():
    """Stochastic SimCLR view-generation pipeline built on albumentations,
    applied inside a tf.data graph via ``tf.numpy_function``."""

    def __init__(self, args):
        # args: config object exposing args.augmentation_config.
        self.args = args

    def build_augmentation(self, image):
        """Return the composed albumentations transform.

        ``image`` is accepted for interface compatibility but is not needed
        to build the (input-independent) transform.
        """
        cfg = self.args.augmentation_config
        # Bug fix: the original passed several values positionally, which
        # bound the crop probability to RandomResizedCrop's `interpolation`
        # and the flip/grayscale probability to `always_apply`. Keyword
        # arguments bind each config value to its intended parameter.
        transform = A.Compose([
            A.RandomResizedCrop(height=cfg["image_height"],
                                width=cfg["image_width"],
                                scale=cfg["cropscale"],
                                ratio=cfg["cropratio"],
                                p=cfg["probability"]),
            A.HorizontalFlip(p=cfg["probability"]),
            A.ColorJitter(brightness=cfg["jitterbrightness"],
                          contrast=cfg["jittercontrast"],
                          saturation=cfg["jittersaturation"],
                          hue=cfg["jitterhue"],
                          always_apply=cfg["alwaysapply"],
                          p=cfg["probability"]),
            A.ToGray(p=cfg["probability"]),
            A.GaussianBlur(blur_limit=cfg["gaussianblurlimit"],
                           sigma_limit=cfg["gaussiansigmalimit"],
                           always_apply=cfg["alwaysapply"],
                           p=cfg["probability"]),
        ])
        return transform

    def augmentation(self, image):
        """Apply the numpy-side augmentation within the tf.data pipeline."""
        aug_img = tf.numpy_function(func=self.aug_fn, inp=[image], Tout=tf.float32)
        # numpy_function erases static shape information; restore it.
        aug_img.set_shape((self.args.augmentation_config["image_height"],
                           self.args.augmentation_config["image_width"], 3))

        aug_img = tf.image.resize(aug_img,
                                  [self.args.augmentation_config["image_height"],
                                   self.args.augmentation_config["image_width"]],
                                  method='bicubic',
                                  preserve_aspect_ratio=False)
        # Bicubic interpolation can overshoot [0, 1]; clamp back into range.
        aug_img = tf.clip_by_value(aug_img, 0.0, 1.0)

        return aug_img

    def aug_fn(self, image):
        """Numpy-side worker: build the transform and apply it to the image.

        Bug fix: the original called ``build_augmentation(image=image)`` and
        then indexed the *returned Compose object* with ``"image"`` without
        ever applying it. The transform must be called on the image.
        """
        transform = self.build_augmentation(image)
        aug_data = transform(image=image)
        return aug_data["image"].astype(np.float32)
55 changes: 55 additions & 0 deletions ssl_study/simclrv1/pretext_task/data/dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import numpy as np
import tensorflow as tf

from .data_aug import Augment

AUTOTUNE = tf.data.AUTOTUNE

class GetDataloader():
    """Builds a tf.data pipeline over image file paths for the pretext task."""

    def __init__(self, args):
        # args: config object exposing bool_config and dataset_config.
        self.args = args

    def dataloader(self, paths):
        '''
        Args:
            paths: List of strings, where each string is path to the image.

        Return:
            dataloader: in-class dataloader
        '''
        # Consume the paths as a dataset and decode images in parallel.
        ds = tf.data.Dataset.from_tensor_slices(paths)
        ds = ds.map(self.parse_data, num_parallel_calls=AUTOTUNE)

        # Optionally cache decoded images in memory.
        if self.args.bool_config["do_cache"]:
            ds = ds.cache()

        # Shuffle, batch and prefetch.
        batch_size = self.args.dataset_config["batch_size"]
        ds = ds.shuffle(batch_size).batch(batch_size).prefetch(AUTOTUNE)

        return ds

    def parse_data(self, path):
        """Read, decode and (optionally) resize a single image to float32 [0,1]."""
        raw = tf.io.read_file(path)
        image = tf.image.decode_jpeg(raw, channels=3)
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        if self.args.bool_config["apply_resize"]:
            target_size = [self.args.dataset_config["image_height"],
                           self.args.dataset_config["image_width"]]
            image = tf.image.resize(image, target_size,
                                    method='bicubic',
                                    preserve_aspect_ratio=False)
            # Bicubic resizing can overshoot [0, 1]; clamp back into range.
            image = tf.clip_by_value(image, 0.0, 1.0)

        return image
5 changes: 5 additions & 0 deletions ssl_study/simclrv1/pretext_task/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .model import SimCLRv1Model

__all__ = [
'SimCLRv1Model'
]
38 changes: 38 additions & 0 deletions ssl_study/simclrv1/pretext_task/models/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import tensorflow as tf

class SimCLRv1Model():
    """Builds the SimCLR v1 network: a CNN backbone plus a 3-layer projection head."""

    def __init__(self, args):
        # args: config object exposing model_config and dataset_config.
        self.args = args

    def get_backbone(self):
        """Get backbone for the model.

        Raises:
            NotImplementedError: if the configured backbone is not supported.
        """
        # SimCLR pretrains from scratch, so no pretrained weights are loaded.
        weights = None

        if self.args.model_config["backbone"] == 'resnet50':
            base_encoder = tf.keras.applications.ResNet50(include_top=False, weights=weights)
            # Bug fix: was `trainabe` (typo), which silently created a dead
            # attribute instead of setting the Keras `trainable` flag.
            base_encoder.trainable = True
        else:
            raise NotImplementedError("Not implemented for this backbone.")

        return base_encoder

    def get_model(self, hidden_1, hidden_2, hidden_3):
        """Get model.

        Args:
            hidden_1, hidden_2, hidden_3: widths of the projection-head layers.

        Returns:
            A tf.keras Model mapping input images to projection-head outputs.
        """
        # Backbone
        base_encoder = self.get_backbone()

        # Stack layers
        inputs = tf.keras.layers.Input(
            (self.args.dataset_config["image_height"],
             self.args.dataset_config["image_width"],
             self.args.dataset_config["channels"]))

        # training=True keeps BatchNorm in training mode during pretraining.
        x = base_encoder(inputs, training=True)

        # NOTE(review): with include_top=False the backbone output is a 4-D
        # feature map; the SimCLR reference applies global average pooling
        # before the projection head, whereas Dense here acts on the last
        # axis only — confirm this is intentional.
        projection_1 = tf.keras.layers.Dense(hidden_1)(x)
        projection_1 = tf.keras.layers.Activation("relu")(projection_1)
        projection_2 = tf.keras.layers.Dense(hidden_2)(projection_1)
        projection_2 = tf.keras.layers.Activation("relu")(projection_2)
        projection_3 = tf.keras.layers.Dense(hidden_3)(projection_2)

        return tf.keras.models.Model(inputs, projection_3)
5 changes: 5 additions & 0 deletions ssl_study/simclrv1/pretext_task/pipeline/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .pipeline import SimCLRv1Pipeline

__all__ = [
'SimCLRv1Pipeline'
]
Loading