From f08b6f9d381b9f3918ab3066b1437ffd20d5a04d Mon Sep 17 00:00:00 2001 From: tanthinhdt Date: Wed, 6 Nov 2024 09:42:22 +0700 Subject: [PATCH 1/7] update(configs/model): update config of BLIP Base models --- configs/model/blip_base_pretrained.yaml | 29 +++++++++++++++++++ ...{blip_base.yaml => blip_base_vanilla.yaml} | 0 2 files changed, 29 insertions(+) create mode 100644 configs/model/blip_base_pretrained.yaml rename configs/model/{blip_base.yaml => blip_base_vanilla.yaml} (100%) diff --git a/configs/model/blip_base_pretrained.yaml b/configs/model/blip_base_pretrained.yaml new file mode 100644 index 0000000..e838414 --- /dev/null +++ b/configs/model/blip_base_pretrained.yaml @@ -0,0 +1,29 @@ +_target_: src.models.imcap_module.IMCAPLitModule + +optimizer: + _target_: torch.optim.AdamW + _partial_: true + lr: 0.00001 + weight_decay: 0.05 + +scheduler: + _target_: torch.optim.lr_scheduler.CosineAnnealingLR + _partial_: true + T_max: ${trainer.max_epochs} + eta_min: 0.000001 + +net: + _target_: transformers.AutoModelForVision2Seq.from_pretrained + pretrained_model_name_or_path: Salesforce/blip-image-captioning-base + cache_dir: models/huggingface + +processor: + _target_: transformers.AutoProcessor.from_pretrained + pretrained_model_name_or_path: Salesforce/blip-image-captioning-base + cache_dir: models/huggingface + +# compile model for faster training with pytorch 2.0 +compile: false + +# HuggingFace repo ID to push model +hf_repo_id: null diff --git a/configs/model/blip_base.yaml b/configs/model/blip_base_vanilla.yaml similarity index 100% rename from configs/model/blip_base.yaml rename to configs/model/blip_base_vanilla.yaml From 133e983688906f1e37ab9ac01b6f8736a1babb0f Mon Sep 17 00:00:00 2001 From: tanthinhdt Date: Wed, 6 Nov 2024 09:43:19 +0700 Subject: [PATCH 2/7] update(configs/train): update training config --- configs/train.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/train.yaml b/configs/train.yaml index b146b95..64ae111 100644 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -5,7 +5,7 @@ defaults: - _self_ - data: flickr30k - - model: blip_base + - model: blip_base_pretrained - callbacks: default - logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`) - trainer: default @@ -15,7 +15,7 @@ defaults: # experiment configs allow for version control of specific hyperparameters # e.g. 
best hyperparameters for given model and datamodule - - experiment: flickr30k_blip-base_v1-0 + - experiment: flickr30k_blip-base_v2-0 # config for hyperparameter optimization - hparams_search: null From 99e1e3c006a68c05909032a36ad3c87fcf826e98 Mon Sep 17 00:00:00 2001 From: tanthinhdt Date: Wed, 6 Nov 2024 09:44:14 +0700 Subject: [PATCH 3/7] update(src/data/flickr30k_datamodule): update transforms --- src/data/flickr30k_datamodule.py | 82 ++++++++------------------------ 1 file changed, 19 insertions(+), 63 deletions(-) diff --git a/src/data/flickr30k_datamodule.py b/src/data/flickr30k_datamodule.py index 0207781..ca7c16c 100644 --- a/src/data/flickr30k_datamodule.py +++ b/src/data/flickr30k_datamodule.py @@ -1,5 +1,4 @@ import os -import torch import polars as pl from pathlib import Path from lightning import LightningDataModule @@ -7,15 +6,7 @@ from torch.utils.data import DataLoader from datasets import Dataset, DatasetDict from transformers import AutoProcessor -from src.data.components import SelectText, Strip, Tokenize, LoadImage -from torchvision.transforms.v2 import ( - Compose, - ToImage, - ToDtype, - Resize, - CenterCrop, - Normalize, -) +from PIL import Image class Flickr30kDataModule(LightningDataModule): @@ -28,13 +19,10 @@ def __init__( processor: AutoProcessor, data_dir: str = "data/flickr30k", use_all_comments: bool = False, - comment_number: int = None, + comment_number: int = 0, padding: str = "max_length", max_length: int = 128, truncation: bool = True, - crop_size: int = 224, - image_mean: Tuple[float] = (0.48145466, 0.4578275, 0.40821073), - image_std: Tuple[float] = (0.26862954, 0.26130258, 0.27577711), train_val_test_split: Tuple[float] = (0.8, 0.1, 0.1), batch_size: int = 64, num_workers: int = 0, @@ -59,12 +47,6 @@ def __init__( The maximum length of the sequence, by default 128. truncation : bool, optional Whether to truncate the sequence, by default True. - image_mean : Tuple[float], optional - The mean values for image normalization, by default (0.48145466, 0.4578275, 0.40821073). - image_std : Tuple[float], optional - The standard deviation values for image normalization, by default (0.26862954, 0.26130258, 0.27577711). - crop_size : int, optional - The size of the crop, by default 224. train_val_test_split : Tuple[float], optional The split ratio for the train, validation, and test sets, by default (0.8, 0.1, 0.1). 
batch_size : int, optional @@ -83,28 +65,7 @@ def __init__( self.save_hyperparameters(logger=False) # data transformations - self.text_transforms = Compose( - [ - SelectText(index=comment_number), - Strip(), - Tokenize( - processor=processor, - max_length=max_length, - padding=padding, - truncation=truncation, - ), - ] - ) - self.vision_transforms = Compose( - [ - LoadImage(image_dir=Path(data_dir) / "flickr30k_images"), - ToImage(), - ToDtype(dtype=torch.float32), - Resize(size=crop_size), - CenterCrop(size=crop_size), - Normalize(mean=image_mean, std=image_std), - ] - ) + self.processor = processor self.dataset: Optional[DatasetDict] = None self.num_examples: int = 0 @@ -132,26 +93,21 @@ def setup(self, stage: Optional[str] = None) -> None: The stage to setup, by default """ def transform(batch): - transformed_batch = { - "pixel_values": [], - "input_ids": [], - "attention_mask": [], - } - - for image_name in batch["image_name"]: - outputs = self.vision_transforms(image_name) - transformed_batch["pixel_values"].append(outputs.squeeze(0)) - - for comment in batch["comment"]: - outputs = self.text_transforms(comment) - for key in outputs: - transformed_batch[key].append(outputs[key].squeeze(0)) - - for key in transformed_batch: - transformed_batch[key] = torch.stack(transformed_batch[key]) - transformed_batch["labels"] = transformed_batch["input_ids"].clone() - - return transformed_batch + images = [ + Image.open(Path(self.hparams.data_dir) / f"flickr30k_images/{image_name}") + for image_name in batch["image_name"] + ] + texts = [comment.strip() for comment in batch["comment"]] + batch = self.processor( + images=images, + text=texts, + padding=self.hparams.padding, + max_length=self.hparams.max_length, + truncation=self.hparams.truncation, + return_tensors="pt", + ) + batch.update({"labels": batch["input_ids"]}) + return batch # Divide batch size by the number of devices. 
if self.trainer is not None: @@ -174,7 +130,7 @@ def transform(batch): }, ) if not self.hparams.use_all_comments: - df = df.group_by("image_name", maintain_order=True).all() + df = df.filter(pl.col("comment_number") == self.hparams.comment_number) self.num_examples = len(df) dataset = Dataset.from_polars(df) dataset = dataset.shuffle(seed=int(os.environ.get("PL_GLOBAL_SEED", 42))) From 15edc01e308cf3d0c27aaab9bb3733d8b1b13ab0 Mon Sep 17 00:00:00 2001 From: tanthinhdt Date: Wed, 6 Nov 2024 09:45:05 +0700 Subject: [PATCH 4/7] update(tests/test_datamodules): update test for flickr30k --- tests/test_datamodules.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/test_datamodules.py b/tests/test_datamodules.py index c072ac3..7cd4be7 100644 --- a/tests/test_datamodules.py +++ b/tests/test_datamodules.py @@ -36,12 +36,10 @@ def test_flickr30k_datamodule(batch_size: int, train_val_test_split: tuple) -> N ) data_dir = "data/flickr30k" use_all_comments = False - comment_number = None + comment_number = 0 padding = "max_length" max_length = 128 truncation = True - image_mean = (0.48145466, 0.4578275, 0.40821073) - image_std = (0.26862954, 0.26130258, 0.27577711) crop_size = 384 assert Path(data_dir, "results.csv").exists(), \ @@ -57,9 +55,6 @@ def test_flickr30k_datamodule(batch_size: int, train_val_test_split: tuple) -> N padding=padding, max_length=max_length, truncation=truncation, - image_mean=image_mean, - image_std=image_std, - crop_size=crop_size, train_val_test_split=train_val_test_split, batch_size=batch_size, ) From 5c8f8a27266541b127d7f89fd385ec1de025c636 Mon Sep 17 00:00:00 2001 From: tanthinhdt Date: Wed, 6 Nov 2024 09:46:23 +0700 Subject: [PATCH 5/7] update(configs/experiment/flickr30k_blip-base_v2-0): update config --- .../experiment/flickr30k_blip-base_v2-0.yaml | 29 ++++++++----------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/configs/experiment/flickr30k_blip-base_v2-0.yaml b/configs/experiment/flickr30k_blip-base_v2-0.yaml index 92f78d3..7397a42 100644 --- a/configs/experiment/flickr30k_blip-base_v2-0.yaml +++ b/configs/experiment/flickr30k_blip-base_v2-0.yaml @@ -5,7 +5,7 @@ defaults: - override /data: flickr30k - - override /model: blip_base + - override /model: blip_base_pretrained - override /callbacks: default - override /trainer: default - override /logger: wandb @@ -26,21 +26,19 @@ model: optimizer: _target_: torch.optim.AdamW _partial_: true - lr: 0.00001 + lr: 0.000001 weight_decay: 0.05 scheduler: _target_: torch.optim.lr_scheduler.CosineAnnealingLR _partial_: true T_max: ${trainer.max_epochs} - eta_min: 0.000001 + eta_min: 0.0000001 net: - _target_: transformers.AutoModelForVision2Seq.from_config - config: - _target_: transformers.AutoConfig.from_pretrained - pretrained_model_name_or_path: Salesforce/blip-image-captioning-base - cache_dir: models/huggingface + _target_: transformers.AutoModelForVision2Seq.from_pretrained + pretrained_model_name_or_path: Salesforce/blip-image-captioning-base + cache_dir: models/huggingface processor: _target_: transformers.AutoProcessor.from_pretrained @@ -54,15 +52,12 @@ model: data: use_all_comments: False comment_number: 0 - padding: max_length - max_length: 64 - truncation: True - image_mean: [0.48145466, 0.4578275, 0.40821073] - image_std: [0.26862954, 0.26130258, 0.27577711] - crop_size: 224 - train_val_test_split: [0.8, 0.2] - batch_size: 1 - num_workers: 0 + padding: longest + # max_length: 64 + # truncation: True + train_val_test_split: [0.7, 0.3] + batch_size: 16 + num_workers: 8 
logger: wandb: From b4c48282fe3dd072c7b7db4f86d43ce5a9d0c862 Mon Sep 17 00:00:00 2001 From: tanthinhdt Date: Wed, 6 Nov 2024 09:47:29 +0700 Subject: [PATCH 6/7] update(configs/experiment/flickr30k_blip-base_v2-0): update config --- configs/experiment/flickr30k_blip-base_v2-0.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/configs/experiment/flickr30k_blip-base_v2-0.yaml b/configs/experiment/flickr30k_blip-base_v2-0.yaml index 7397a42..60d37de 100644 --- a/configs/experiment/flickr30k_blip-base_v2-0.yaml +++ b/configs/experiment/flickr30k_blip-base_v2-0.yaml @@ -53,8 +53,6 @@ data: use_all_comments: False comment_number: 0 padding: longest - # max_length: 64 - # truncation: True train_val_test_split: [0.7, 0.3] batch_size: 16 num_workers: 8 From eb7aefe872c754c9ec6a2a8efe3bcf16537d2e5b Mon Sep 17 00:00:00 2001 From: tanthinhdt Date: Wed, 6 Nov 2024 09:47:48 +0700 Subject: [PATCH 7/7] update(configs/data/flickr30k): update transforms --- configs/data/flickr30k.yaml | 3 --- 1 file changed, 3 deletions(-) diff --git a/configs/data/flickr30k.yaml b/configs/data/flickr30k.yaml index 62ca72b..21d6271 100644 --- a/configs/data/flickr30k.yaml +++ b/configs/data/flickr30k.yaml @@ -7,9 +7,6 @@ comment_number: 0 padding: max_length max_length: 64 truncation: True -image_mean: [0.48145466, 0.4578275, 0.40821073] -image_std: [0.26862954, 0.26130258, 0.27577711] -crop_size: 224 train_val_test_split: [0.8, 0.2] # train, (val,) test batch_size: 16 # Needs to be divisible by the number of devices (e.g., if in a distributed setup)
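
Note on the model config change (PATCH 1/7 and PATCH 5/7): the v2-0 experiment now resolves `net` through `AutoModelForVision2Seq.from_pretrained` instead of `from_config`, so training starts from the released BLIP captioning weights rather than a randomly initialized network. The sketch below shows how Hydra would instantiate those nodes; the inline YAML mirrors blip_base_pretrained.yaml and is only an illustration, not part of the patches.

# Illustrative sketch (assumes hydra-core, omegaconf and transformers are installed
# and the Salesforce checkpoint is reachable over the network or in the local cache).
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    net:
      _target_: transformers.AutoModelForVision2Seq.from_pretrained
      pretrained_model_name_or_path: Salesforce/blip-image-captioning-base
      cache_dir: models/huggingface
    processor:
      _target_: transformers.AutoProcessor.from_pretrained
      pretrained_model_name_or_path: Salesforce/blip-image-captioning-base
      cache_dir: models/huggingface
    """
)

net = instantiate(cfg.net)              # BLIP captioning model with pretrained weights
processor = instantiate(cfg.processor)  # matching image/text processor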
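
For the datamodule change in PATCH 3/7, the separate torchvision and text transform pipelines are replaced by a single call to the HuggingFace processor. A minimal standalone sketch of that preprocessing follows; the image name and caption are hypothetical placeholders, and the padding/max_length values follow the flickr30k data config.

# Illustrative sketch of the processor-based preprocessing; the sample file name
# and caption are placeholders, not taken from the patches.
from pathlib import Path

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    "Salesforce/blip-image-captioning-base", cache_dir="models/huggingface"
)

data_dir = Path("data/flickr30k")
raw_batch = {
    "image_name": ["example.jpg"],
    "comment": ["  a person walking a dog in a park .  "],
}

images = [Image.open(data_dir / "flickr30k_images" / name) for name in raw_batch["image_name"]]
texts = [comment.strip() for comment in raw_batch["comment"]]

inputs = processor(
    images=images,
    text=texts,
    padding="max_length",  # flickr30k.yaml default; the v2-0 experiment overrides this to "longest"
    max_length=64,
    truncation=True,
    return_tensors="pt",
)
inputs["labels"] = inputs["input_ids"]  # captions serve as labels, as in the new transform
# inputs now contains pixel_values, input_ids, attention_mask and labels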
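
The same patch also changes how a single caption per image is selected when use_all_comments is False: instead of aggregating every comment per image with group_by(...).all(), the frame is filtered to the configured comment_number (whose default moves from None to 0). A toy example with made-up rows:

# Toy frame; the real data comes from data/flickr30k/results.csv.
import polars as pl

df = pl.DataFrame(
    {
        "image_name": ["a.jpg", "a.jpg", "b.jpg", "b.jpg"],
        "comment_number": [0, 1, 0, 1],
        "comment": ["a dog runs", "dog running fast", "a red car", "car on a road"],
    }
)

# Old behaviour: one row per image, with all comments collected into a list column.
grouped = df.group_by("image_name", maintain_order=True).all()

# New behaviour: one row per image, keeping only the comment with the chosen number.
selected = df.filter(pl.col("comment_number") == 0)
print(selected)  # 2 rows, one caption per image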