diff --git a/configs/simclrv1_pretext_config.py b/configs/simclrv1_pretext_config.py
new file mode 100644
index 0000000..98569c2
--- /dev/null
+++ b/configs/simclrv1_pretext_config.py
@@ -0,0 +1,80 @@
+import ml_collections
+
+def get_wandb_configs() -> ml_collections.ConfigDict:
+    configs = ml_collections.ConfigDict()
+    configs.project = "ssl-study"
+    configs.entity = "wandb_fc"
+
+    return configs
+
+def get_dataset_configs() -> ml_collections.ConfigDict:
+    configs = ml_collections.ConfigDict()
+    configs.image_height = 224  # default: 224
+    configs.image_width = 224  # default: 224
+    configs.channels = 3
+    configs.batch_size = 64
+    configs.num_classes = 200
+
+    return configs
+
+def get_augment_configs() -> ml_collections.ConfigDict:
+    configs = ml_collections.ConfigDict()
+    configs.image_height = 224  # default: 224
+    configs.image_width = 224  # default: 224
+    configs.cropscale = (0.08, 1.0)
+    configs.cropratio = (0.75, 1.3333333333333333)
+    configs.jitterbrightness = 0.2
+    configs.jittercontrast = 0.2
+    configs.jittersaturation = 0.2
+    configs.jitterhue = 0.2
+    configs.gaussianblurlimit = (3, 7)
+    configs.gaussiansigmalimit = 0
+    configs.alwaysapply = False
+    configs.probability = 0.5
+
+    return configs
+
+def get_bool_configs() -> ml_collections.ConfigDict:
+    configs = ml_collections.ConfigDict()
+    configs.apply_resize = True
+    configs.do_cache = False
+    configs.use_cosine_similarity = True
+
+    return configs
+
+def get_model_configs() -> ml_collections.ConfigDict:
+    configs = ml_collections.ConfigDict()
+    configs.backbone = "resnet50"
+    configs.hidden1 = 256
+    configs.hidden2 = 128
+    configs.hidden3 = 50
+
+    return configs
+
+def get_train_configs() -> ml_collections.ConfigDict:
+    configs = ml_collections.ConfigDict()
+    configs.epochs = 30
+    configs.temperature = 0.5
+    configs.s = 1
+
+    return configs
+
+def get_learning_rate_configs() -> ml_collections.ConfigDict:
+    configs = ml_collections.ConfigDict()
+    configs.decay_steps = 1000
+    configs.initial_learning_rate = 0.1
+
+    return configs
+
+def get_config() -> ml_collections.ConfigDict:
+    config = ml_collections.ConfigDict()
+    config.seed = 0
+    config.wandb_config = get_wandb_configs()
+    config.dataset_config = get_dataset_configs()
+    config.augmentation_config = get_augment_configs()
+    config.bool_config = get_bool_configs()
+    config.train_config = get_train_configs()
+    config.learning_rate_config = get_learning_rate_configs()
+    config.model_config = get_model_configs()
+
+    return config
\ No newline at end of file
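Since the whole tree above is built from plain functions returning ConfigDicts, it can be sanity-checked in isolation, and ml_collections' config_flags makes every leaf overridable at launch. A minimal sketch, run from the repo root; the override values shown are illustrative:

    # Inspect the config tree directly, without going through absl flags.
    from configs.simclrv1_pretext_config import get_config

    config = get_config()
    assert config.train_config.temperature == 0.5
    assert config.model_config.hidden3 == 50

    # With config_flags.DEFINE_config_file, any leaf can be overridden, e.g.:
    #   python simclrv1_pretext.py --config=configs/simclrv1_pretext_config.py \
    #       --config.train_config.epochs=5 --config.dataset_config.batch_size=32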
diff --git a/simclrv1_pretext.py b/simclrv1_pretext.py
new file mode 100644
index 0000000..181e7a0
--- /dev/null
+++ b/simclrv1_pretext.py
@@ -0,0 +1,59 @@
+# General imports
+import os
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
+import wandb
+from absl import app
+from absl import flags
+import tensorflow as tf
+from ml_collections.config_flags import config_flags
+
+# Import modules
+from ssl_study.data import download_dataset, preprocess_dataframe
+from ssl_study.simclrv1.pretext_task.data import GetDataloader
+from ssl_study.simclrv1.pretext_task.models import SimCLRv1Model
+from ssl_study.simclrv1.pretext_task.pipeline import SimCLRv1Pipeline
+
+FLAGS = flags.FLAGS
+CONFIG = config_flags.DEFINE_config_file("config")
+
+def main(_):
+    with wandb.init(
+        entity=CONFIG.value.wandb_config.entity,
+        project=CONFIG.value.wandb_config.project,
+        job_type='simclrv1_pretext',
+        config=CONFIG.value.to_dict(),
+    ):
+        # Access all hyperparameter values through wandb.config
+        config = wandb.config
+        # Seed everything
+        tf.random.set_seed(config.seed)
+
+        # Load the dataframes
+        inclass_df = download_dataset('in-class', 'unlabelled-dataset')
+
+        # Preprocess the DataFrames
+        inclass_paths = preprocess_dataframe(inclass_df, is_labelled=False)
+
+        # Build dataloaders
+        dataset = GetDataloader(config)
+        inclassloader = dataset.dataloader(inclass_paths)
+
+        # Model
+        tf.keras.backend.clear_session()
+        model = SimCLRv1Model(config).get_model(
+            config.model_config["hidden1"], config.model_config["hidden2"], config.model_config["hidden3"])
+        model.summary()
+
+        # Build the pipeline
+        pipeline = SimCLRv1Pipeline(model, config)
+        optimizer = pipeline.get_optimizer()
+        criterion = pipeline.get_criterion()
+
+        epoch_wise_loss, resnet_simclr = pipeline.train_simclr(
+            model, inclassloader, optimizer, criterion,
+            temperature=config.train_config["temperature"],
+            epochs=config.train_config["epochs"])
+
+
+if __name__ == "__main__":
+    app.run(main)
\ No newline at end of file
diff --git a/ssl_study/data/__init__.py b/ssl_study/data/__init__.py
index 64df8f1..60f4e6d 100644
--- a/ssl_study/data/__init__.py
+++ b/ssl_study/data/__init__.py
@@ -1,6 +1,6 @@
-from .dataset import download_dataset, preprocess_dataset
+from .dataset import download_dataset, preprocess_dataframe
 from .dataloader import GetDataloader
 
 __all__ = [
-    'download_dataset', 'preprocess_dataset', 'GetDataloader'
+    'download_dataset', 'preprocess_dataframe', 'GetDataloader'
 ]
\ No newline at end of file
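One detail worth noting when reading main() above: wandb.init(config=...) stores the nested ConfigDict sections as plain dicts, so top-level keys read as attributes while nested keys need item access, which is why the script writes config.model_config["hidden1"]. A minimal standalone sketch with toy values (offline mode, so no W&B account is needed):

    import wandb

    with wandb.init(mode="offline", config={"seed": 0, "model_config": {"hidden1": 256}}):
        config = wandb.config
        print(config.seed)                     # attribute access on a top-level key
        print(config.model_config["hidden1"])  # item access on a nested section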
diff --git a/ssl_study/data/dataset.py b/ssl_study/data/dataset.py
index f47f036..4102abc 100644
--- a/ssl_study/data/dataset.py
+++ b/ssl_study/data/dataset.py
@@ -34,6 +34,10 @@ def download_dataset(dataset_name: str,
         data_df = pd.read_csv(save_at+'train.csv')
     elif dataset_name == 'val' and os.path.exists(save_at+'valid.csv'):
         data_df = pd.read_csv(save_at+'valid.csv')
+    elif dataset_name == 'in-class' and os.path.exists(save_at+'in-class.csv'):
+        data_df = pd.read_csv(save_at+'in-class.csv')
+    elif dataset_name == 'out-class' and os.path.exists(save_at+'out-class.csv'):
+        data_df = pd.read_csv(save_at+'out-class.csv')
     else:
         data_df = None
         print('Downloading dataset...')
@@ -82,19 +86,23 @@ def download_dataset(dataset_name: str,
     if dataset_name == 'val' and not os.path.exists(save_at+'valid.csv'):
         data_df.to_csv(save_at+'valid.csv', index=False)
 
-    return data_df
-
+    if dataset_name == 'in-class' and not os.path.exists(save_at+'in-class.csv'):
+        data_df.to_csv(save_at+'in-class.csv', index=False)
 
-def preprocess_dataset(df):
-    # TODO: take care of df without labels.
-    # Remove unnecessary columns
-    df = df.drop(['image_id', 'width', 'height'], axis=1)
-    assert len(df.columns) == 2
+    if dataset_name == 'out-class' and not os.path.exists(save_at+'out-class.csv'):
+        data_df.to_csv(save_at+'out-class.csv', index=False)
 
-    # Fix types
-    df[['label']] = df[['label']].apply(pd.to_numeric)
+    return data_df
 
+def preprocess_dataframe(df, is_labelled=True):
+    df = df.drop(['image_id', 'width', 'height'], axis=1)
     image_paths = df.image_path.values
-    labels = df.label.values
 
-    return image_paths, labels
\ No newline at end of file
+    if is_labelled:
+        assert len(df.columns) == 2
+        # Fix types
+        df[['label']] = df[['label']].apply(pd.to_numeric)
+        labels = df.label.values
+        return image_paths, labels
+    else:
+        return image_paths
\ No newline at end of file
diff --git a/ssl_study/simclrv1/downstream_task/simclrv1_train.py b/ssl_study/simclrv1/downstream_task/simclrv1_train.py
new file mode 100644
index 0000000..e69de29
diff --git a/ssl_study/simclrv1/pretext_task/data/__init__.py b/ssl_study/simclrv1/pretext_task/data/__init__.py
new file mode 100644
index 0000000..1ae1309
--- /dev/null
+++ b/ssl_study/simclrv1/pretext_task/data/__init__.py
@@ -0,0 +1,6 @@
+from .data_aug import Augment
+from .dataloader import GetDataloader
+
+__all__ = [
+    'Augment', 'GetDataloader'
+]
\ No newline at end of file
diff --git a/ssl_study/simclrv1/pretext_task/data/data_aug.py b/ssl_study/simclrv1/pretext_task/data/data_aug.py
new file mode 100644
index 0000000..73554b7
--- /dev/null
+++ b/ssl_study/simclrv1/pretext_task/data/data_aug.py
@@ -0,0 +1,56 @@
+import numpy as np
+import tensorflow as tf
+import albumentations as A
+
+AUTOTUNE = tf.data.AUTOTUNE
+
+class Augment():
+    def __init__(self, args):
+        self.args = args
+
+    def build_augmentation(self):
+        # SimCLR augmentation policy: random resized crop, horizontal flip,
+        # color jitter, random grayscale and Gaussian blur.
+        transform = A.Compose([
+            A.RandomResizedCrop(height=self.args.augmentation_config["image_height"],
+                                width=self.args.augmentation_config["image_width"],
+                                scale=self.args.augmentation_config["cropscale"],
+                                ratio=self.args.augmentation_config["cropratio"],
+                                p=self.args.augmentation_config["probability"]),
+            A.HorizontalFlip(p=self.args.augmentation_config["probability"]),
+            A.ColorJitter(brightness=self.args.augmentation_config["jitterbrightness"],
+                          contrast=self.args.augmentation_config["jittercontrast"],
+                          saturation=self.args.augmentation_config["jittersaturation"],
+                          hue=self.args.augmentation_config["jitterhue"],
+                          always_apply=self.args.augmentation_config["alwaysapply"],
+                          p=self.args.augmentation_config["probability"]),
+            A.ToGray(p=self.args.augmentation_config["probability"]),
+            A.GaussianBlur(blur_limit=self.args.augmentation_config["gaussianblurlimit"],
+                           sigma_limit=self.args.augmentation_config["gaussiansigmalimit"],
+                           always_apply=self.args.augmentation_config["alwaysapply"],
+                           p=self.args.augmentation_config["probability"])
+        ])
+        return transform
+
+    def augmentation(self, image):
+        # Expects a batched image tensor of shape (B, H, W, 3).
+        aug_img = tf.numpy_function(func=self.aug_fn, inp=[image], Tout=tf.float32)
+        aug_img.set_shape((None,
+                           self.args.augmentation_config["image_height"],
+                           self.args.augmentation_config["image_width"], 3))
+
+        aug_img = tf.image.resize(aug_img,
+                                  [self.args.augmentation_config["image_height"],
+                                   self.args.augmentation_config["image_width"]],
+                                  method='bicubic',
+                                  preserve_aspect_ratio=False)
+        aug_img = tf.clip_by_value(aug_img, 0.0, 1.0)
+
+        return aug_img
+
+    def aug_fn(self, image):
+        # Albumentations operates on a single HWC image, so apply the
+        # transform image by image when a batch comes in.
+        transform = self.build_augmentation()
+        if image.ndim == 4:
+            return np.stack([transform(image=img)["image"] for img in image]).astype(np.float32)
+        return transform(image=image)["image"].astype(np.float32)
\ No newline at end of file
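The tf.numpy_function bridge in Augment.augmentation is the standard way to run albumentations inside a TF graph, at the cost of losing static shape information, which is why set_shape follows the call. A self-contained sketch of the same pattern on random data (sizes are illustrative):

    import numpy as np
    import tensorflow as tf
    import albumentations as A

    transform = A.Compose([A.HorizontalFlip(p=0.5)])

    def aug_fn(image):
        return transform(image=image)["image"].astype(np.float32)

    def augment(image):  # image: (224, 224, 3) float32 in [0, 1]
        out = tf.numpy_function(aug_fn, [image], tf.float32)
        out.set_shape((224, 224, 3))  # restore the shape lost across the NumPy boundary
        return out

    ds = tf.data.Dataset.from_tensor_slices(np.random.rand(4, 224, 224, 3).astype(np.float32))
    ds = ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE)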
diff --git a/ssl_study/simclrv1/pretext_task/data/dataloader.py b/ssl_study/simclrv1/pretext_task/data/dataloader.py
new file mode 100644
index 0000000..d7c72a8
--- /dev/null
+++ b/ssl_study/simclrv1/pretext_task/data/dataloader.py
@@ -0,0 +1,56 @@
+import numpy as np
+import tensorflow as tf
+
+from .data_aug import Augment
+
+AUTOTUNE = tf.data.AUTOTUNE
+
+class GetDataloader():
+    def __init__(self, args):
+        self.args = args
+
+    def dataloader(self, paths):
+        '''
+        Args:
+            paths: List of strings, where each string is the path to an image.
+
+        Return:
+            dataloader: in-class dataloader
+        '''
+        # Consume dataframe
+        dataloader = tf.data.Dataset.from_tensor_slices(paths)
+
+        # Load the image
+        dataloader = (
+            dataloader
+            .map(self.parse_data, num_parallel_calls=AUTOTUNE)
+        )
+
+        if self.args.bool_config["do_cache"]:
+            dataloader = dataloader.cache()
+
+        # Add general stuff
+        dataloader = (
+            dataloader
+            .shuffle(self.args.dataset_config["batch_size"])
+            # NT-Xent assumes full batches, so drop the final partial batch.
+            .batch(self.args.dataset_config["batch_size"], drop_remainder=True)
+            .prefetch(AUTOTUNE)
+        )
+
+        return dataloader
+
+    def parse_data(self, path):
+        # Parse Image
+        image_string = tf.io.read_file(path)
+        image = tf.image.decode_jpeg(image_string, channels=3)
+        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
+        if self.args.bool_config["apply_resize"]:
+            image = tf.image.resize(image,
+                                    [self.args.dataset_config["image_height"],
+                                     self.args.dataset_config["image_width"]],
+                                    method='bicubic',
+                                    preserve_aspect_ratio=False)
+            image = tf.clip_by_value(image, 0.0, 1.0)
+
+        return image
\ No newline at end of file
diff --git a/ssl_study/simclrv1/pretext_task/models/__init__.py b/ssl_study/simclrv1/pretext_task/models/__init__.py
new file mode 100644
index 0000000..c89fb18
--- /dev/null
+++ b/ssl_study/simclrv1/pretext_task/models/__init__.py
@@ -0,0 +1,5 @@
+from .model import SimCLRv1Model
+
+__all__ = [
+    'SimCLRv1Model'
+]
\ No newline at end of file
diff --git a/ssl_study/simclrv1/pretext_task/models/model.py b/ssl_study/simclrv1/pretext_task/models/model.py
new file mode 100644
index 0000000..fe20640
--- /dev/null
+++ b/ssl_study/simclrv1/pretext_task/models/model.py
@@ -0,0 +1,41 @@
+import tensorflow as tf
+
+class SimCLRv1Model():
+    def __init__(self, args):
+        self.args = args
+
+    def get_backbone(self):
+        """Get the backbone for the model."""
+        weights = None
+
+        if self.args.model_config["backbone"] == 'resnet50':
+            base_encoder = tf.keras.applications.ResNet50(include_top=False, weights=weights)
+            base_encoder.trainable = True
+        else:
+            raise NotImplementedError("Not implemented for this backbone.")
+
+        return base_encoder
+
+    def get_model(self, hidden_1, hidden_2, hidden_3):
+        """Get the model."""
+        # Backbone
+        base_encoder = self.get_backbone()
+
+        # Stack layers
+        inputs = tf.keras.layers.Input(
+            (self.args.dataset_config["image_height"],
+             self.args.dataset_config["image_width"],
+             self.args.dataset_config["channels"]))
+
+        x = base_encoder(inputs, training=True)
+        # Pool the spatial feature map to one vector per image so the
+        # projection head and the NT-Xent loss receive rank-2 tensors.
+        x = tf.keras.layers.GlobalAveragePooling2D()(x)
+
+        projection_1 = tf.keras.layers.Dense(hidden_1)(x)
+        projection_1 = tf.keras.layers.Activation("relu")(projection_1)
+        projection_2 = tf.keras.layers.Dense(hidden_2)(projection_1)
+        projection_2 = tf.keras.layers.Activation("relu")(projection_2)
+        projection_3 = tf.keras.layers.Dense(hidden_3)(projection_2)
+
+        return tf.keras.models.Model(inputs, projection_3)
\ No newline at end of file
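To make the shapes in get_model concrete: with a 224x224 input, ResNet50 with include_top=False emits a (7, 7, 2048) feature map, global average pooling reduces it to a 2048-vector, and the projection head maps 2048 -> 256 -> 128 -> 50 with the defaults from the config above. A standalone sketch:

    import tensorflow as tf

    inputs = tf.keras.layers.Input((224, 224, 3))
    backbone = tf.keras.applications.ResNet50(include_top=False, weights=None)
    x = backbone(inputs, training=True)              # (None, 7, 7, 2048)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)  # (None, 2048)
    x = tf.keras.layers.Dense(256, activation="relu")(x)
    x = tf.keras.layers.Dense(128, activation="relu")(x)
    z = tf.keras.layers.Dense(50)(x)                 # projection fed to NT-Xent
    model = tf.keras.Model(inputs, z)
    print(model.output_shape)                        # (None, 50)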
diff --git a/ssl_study/simclrv1/pretext_task/pipeline/__init__.py b/ssl_study/simclrv1/pretext_task/pipeline/__init__.py
new file mode 100644
index 0000000..6e24ee2
--- /dev/null
+++ b/ssl_study/simclrv1/pretext_task/pipeline/__init__.py
@@ -0,0 +1,5 @@
+from .pipeline import SimCLRv1Pipeline
+
+__all__ = [
+    'SimCLRv1Pipeline'
+]
\ No newline at end of file
diff --git a/ssl_study/simclrv1/pretext_task/pipeline/pipeline.py b/ssl_study/simclrv1/pretext_task/pipeline/pipeline.py
new file mode 100644
index 0000000..0fa9e54
--- /dev/null
+++ b/ssl_study/simclrv1/pretext_task/pipeline/pipeline.py
@@ -0,0 +1,90 @@
+import numpy as np
+import wandb
+from tqdm import tqdm
+import tensorflow as tf
+
+from ssl_study.simclrv1.pretext_task.utils import (
+    _dot_simililarity_dim1 as sim_func_dim1,
+    _dot_simililarity_dim2 as sim_func_dim2,
+    get_negative_mask,
+)
+from ssl_study.simclrv1.pretext_task.data import Augment
+
+class SimCLRv1Pipeline():
+    def __init__(self, model, args):
+        self.args = args
+        self.model = model
+
+    @tf.function
+    def train_step(self, xis, xjs, model, optimizer, criterion, temperature):
+        batch_size = self.args.dataset_config["batch_size"]
+        with tf.GradientTape() as tape:
+            zis = model(xis)
+            zjs = model(xjs)
+
+            # Normalize the projection feature vectors
+            zis = tf.math.l2_normalize(zis, axis=1)
+            zjs = tf.math.l2_normalize(zjs, axis=1)
+
+            l_pos = sim_func_dim1(zis, zjs)
+            l_pos = tf.reshape(l_pos, (batch_size, 1))
+            l_pos /= temperature
+
+            negatives = tf.concat([zjs, zis], axis=0)
+
+            loss = 0
+
+            for positives in [zis, zjs]:
+                l_neg = sim_func_dim2(positives, negatives)
+
+                # The positive logit sits at column 0, hence the all-zero labels.
+                labels = tf.zeros(batch_size, dtype=tf.int32)
+
+                l_neg = tf.boolean_mask(l_neg, get_negative_mask(batch_size))
+                l_neg = tf.reshape(l_neg, (batch_size, -1))
+                l_neg /= temperature
+
+                logits = tf.concat([l_pos, l_neg], axis=1)
+                loss += criterion(y_pred=logits, y_true=labels)
+
+            loss = loss / (2 * batch_size)
+
+        gradients = tape.gradient(loss, model.trainable_variables)
+        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
+
+        return loss
+
+    def train_simclr(self, model, dataset, optimizer, criterion, temperature=0.1, epochs=100):
+        epoch_wise_loss = []
+
+        augment = Augment(self.args)
+
+        for epoch in tqdm(range(epochs)):
+            step_wise_loss = []
+            for image_batch in dataset:
+                # Two independently augmented views of the same batch
+                a = augment.augmentation(image_batch)
+                b = augment.augmentation(image_batch)
+
+                loss = self.train_step(a, b, model, optimizer, criterion, temperature)
+                step_wise_loss.append(loss)
+
+            epoch_wise_loss.append(np.mean(step_wise_loss))
+            wandb.log({"nt_xentloss": np.mean(step_wise_loss)})
+
+            if epoch % 10 == 0:
+                print("epoch: {} loss: {:.3f}".format(epoch + 1, np.mean(step_wise_loss)))
+
+        return epoch_wise_loss, model
+
+    def get_criterion(self):
+        return tf.keras.losses.SparseCategoricalCrossentropy(
+            from_logits=True, reduction=tf.keras.losses.Reduction.SUM)
+
+    def get_optimizer(self):
+        learning_rate = tf.keras.optimizers.schedules.CosineDecay(
+            initial_learning_rate=self.args.learning_rate_config['initial_learning_rate'],
+            decay_steps=self.args.learning_rate_config['decay_steps'])
+        optimizer = tf.keras.optimizers.SGD(learning_rate)
+
+        return optimizer
\ No newline at end of file
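For intuition on the bookkeeping in train_step: with batch size N, each anchor has one positive and 2N - 2 negatives (its own two slots are masked out), so the logits tensor is (N, 2N - 1) with the positive at column 0, which is why the labels are all zeros. A small numeric sketch (N = 2, random features):

    import numpy as np
    import tensorflow as tf

    N, C, temperature = 2, 8, 0.5
    zis = tf.math.l2_normalize(tf.random.normal((N, C)), axis=1)
    zjs = tf.math.l2_normalize(tf.random.normal((N, C)), axis=1)

    l_pos = tf.reduce_sum(zis * zjs, axis=1, keepdims=True) / temperature  # (N, 1)
    negatives = tf.concat([zjs, zis], axis=0)                              # (2N, C)
    l_neg = tf.matmul(zis, negatives, transpose_b=True)                    # (N, 2N)

    mask = np.ones((N, 2 * N), dtype=bool)
    for i in range(N):
        mask[i, i] = mask[i, i + N] = False  # drop the positive and the self-slot
    l_neg = tf.reshape(tf.boolean_mask(l_neg, mask), (N, -1)) / temperature  # (N, 2N - 2)

    logits = tf.concat([l_pos, l_neg], axis=1)  # (N, 2N - 1)
    labels = tf.zeros(N, dtype=tf.int32)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(labels, logits)
    print(float(loss))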
diff --git a/ssl_study/simclrv1/pretext_task/utils/__init__.py b/ssl_study/simclrv1/pretext_task/utils/__init__.py
new file mode 100644
index 0000000..15d527f
--- /dev/null
+++ b/ssl_study/simclrv1/pretext_task/utils/__init__.py
@@ -0,0 +1,8 @@
+from .helpers import get_negative_mask
+from .losses import (_cosine_simililarity_dim1, _cosine_simililarity_dim2,
+                     _dot_simililarity_dim1, _dot_simililarity_dim2)
+
+__all__ = [
+    'get_negative_mask', '_cosine_simililarity_dim1', '_cosine_simililarity_dim2',
+    '_dot_simililarity_dim1', '_dot_simililarity_dim2'
+]
\ No newline at end of file
diff --git a/ssl_study/simclrv1/pretext_task/utils/helpers.py b/ssl_study/simclrv1/pretext_task/utils/helpers.py
new file mode 100644
index 0000000..41fa14c
--- /dev/null
+++ b/ssl_study/simclrv1/pretext_task/utils/helpers.py
@@ -0,0 +1,12 @@
+import tensorflow as tf
+import numpy as np
+
+def get_negative_mask(batch_size):
+    # Return a mask that removes the similarity score of equal/similar images.
+    # This ensures that only distinct pairs of images get their similarity
+    # scores passed as negative examples.
+    negative_mask = np.ones((batch_size, 2 * batch_size), dtype=bool)
+    for i in range(batch_size):
+        negative_mask[i, i] = 0
+        negative_mask[i, i + batch_size] = 0
+    return tf.constant(negative_mask)
\ No newline at end of file
diff --git a/ssl_study/simclrv1/pretext_task/utils/losses.py b/ssl_study/simclrv1/pretext_task/utils/losses.py
new file mode 100644
index 0000000..764b280
--- /dev/null
+++ b/ssl_study/simclrv1/pretext_task/utils/losses.py
@@ -0,0 +1,35 @@
+import tensorflow as tf
+
+# NOTE: tf.keras.losses.CosineSimilarity is a loss, so it returns the
+# *negative* cosine similarity; these two helpers therefore return -cos(x, y).
+cosine_sim_1d = tf.keras.losses.CosineSimilarity(axis=1, reduction=tf.keras.losses.Reduction.NONE)
+cosine_sim_2d = tf.keras.losses.CosineSimilarity(axis=2, reduction=tf.keras.losses.Reduction.NONE)
+
+
+def _cosine_simililarity_dim1(x, y):
+    v = cosine_sim_1d(x, y)
+    return v
+
+
+def _cosine_simililarity_dim2(x, y):
+    # x shape: (N, 1, C)
+    # y shape: (1, 2N, C)
+    # v shape: (N, 2N)
+    v = cosine_sim_2d(tf.expand_dims(x, 1), tf.expand_dims(y, 0))
+    return v
+
+
+def _dot_simililarity_dim1(x, y):
+    # x shape: (N, 1, C)
+    # y shape: (N, C, 1)
+    # v shape: (N, 1, 1)
+    v = tf.matmul(tf.expand_dims(x, 1), tf.expand_dims(y, 2))
+    return v
+
+
+def _dot_simililarity_dim2(x, y):
+    # x shape: (N, 1, C)
+    # y shape: (1, C, 2N)
+    # v shape: (N, 2N)
+    v = tf.tensordot(tf.expand_dims(x, 1), tf.expand_dims(tf.transpose(y), 0), axes=2)
+    return v
\ No newline at end of file
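A quick shape check for the _dot_simililarity_dim2 broadcasting above: expanding x to (N, 1, C) and the transposed y to (1, C, 2N), then contracting two axes with tensordot, yields one dot product per (x_i, y_j) pair, the same numbers a plain matmul gives:

    import tensorflow as tf

    N, C = 4, 16
    x = tf.random.normal((N, C))
    y = tf.random.normal((2 * N, C))

    v = tf.tensordot(tf.expand_dims(x, 1), tf.expand_dims(tf.transpose(y), 0), axes=2)
    print(v.shape)  # (4, 8): one similarity per (x_i, y_j) pair

    # Identical result via a plain matmul:
    delta = tf.reduce_max(tf.abs(v - tf.matmul(x, y, transpose_b=True)))
    print(delta.numpy())  # ~0.0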