Merge branch 'dev' of https://github.com/tanthinhdt/imcap into dev
tanthinhdt committed Nov 6, 2024
2 parents fdc4765 + 9add976 commit 7754af6
Showing 7 changed files with 61 additions and 91 deletions.
3 changes: 0 additions & 3 deletions configs/data/flickr30k.yaml
@@ -7,9 +7,6 @@ comment_number: 0
padding: max_length
max_length: 64
truncation: True
image_mean: [0.48145466, 0.4578275, 0.40821073]
image_std: [0.26862954, 0.26130258, 0.27577711]
crop_size: 224

train_val_test_split: [0.8, 0.2] # train, (val,) test
batch_size: 16 # Needs to be divisible by the number of devices (e.g., if in a distributed setup)
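
A note on this change: the image_mean, image_std, and crop_size entries drop out of the data config because image preprocessing is now delegated entirely to the HuggingFace processor (see the datamodule diff further down). A minimal sketch, assuming the Salesforce/blip-image-captioning-base checkpoint is reachable from the local cache or the Hub, of where the same statistics already live:

from transformers import AutoProcessor

# The BLIP processor bundles its own resize target and normalization statistics,
# so duplicating them in configs/data/flickr30k.yaml was redundant.
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
print(processor.image_processor.image_mean)  # CLIP-style channel means, roughly [0.481, 0.458, 0.408]
print(processor.image_processor.image_std)   # matching channel standard deviations
print(processor.image_processor.size)        # resolution the processor resizes images to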
27 changes: 10 additions & 17 deletions configs/experiment/flickr30k_blip-base_v2-0.yaml
@@ -5,7 +5,7 @@

defaults:
- override /data: flickr30k
- override /model: blip_base
- override /model: blip_base_pretrained
- override /callbacks: default
- override /trainer: default
- override /logger: wandb
@@ -26,21 +26,19 @@ model:
optimizer:
_target_: torch.optim.AdamW
_partial_: true
lr: 0.00001
lr: 0.000001
weight_decay: 0.05

scheduler:
_target_: torch.optim.lr_scheduler.CosineAnnealingLR
_partial_: true
T_max: ${trainer.max_epochs}
eta_min: 0.000001
eta_min: 0.0000001

net:
_target_: transformers.AutoModelForVision2Seq.from_config
config:
_target_: transformers.AutoConfig.from_pretrained
pretrained_model_name_or_path: Salesforce/blip-image-captioning-base
cache_dir: models/huggingface
_target_: transformers.AutoModelForVision2Seq.from_pretrained
pretrained_model_name_or_path: Salesforce/blip-image-captioning-base
cache_dir: models/huggingface

processor:
_target_: transformers.AutoProcessor.from_pretrained
@@ -54,15 +52,10 @@ model:
data:
use_all_comments: False
comment_number: 0
padding: max_length
max_length: 64
truncation: True
image_mean: [0.48145466, 0.4578275, 0.40821073]
image_std: [0.26862954, 0.26130258, 0.27577711]
crop_size: 224
train_val_test_split: [0.8, 0.2]
batch_size: 1
num_workers: 0
padding: longest
train_val_test_split: [0.7, 0.3]
batch_size: 16
num_workers: 8

logger:
wandb:
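
The substantive change in this experiment config is the net block: v1-0 built the model from an AutoConfig, which yields randomly initialised weights, while v2-0 loads the published pretrained checkpoint and fine-tunes it. The lower lr and eta_min, the longest padding, the 0.7/0.3 split, and the larger batch size presumably follow from fine-tuning rather than training from scratch. A rough sketch of what the old and new Hydra _target_s resolve to (illustrative only, not code from this repo):

from transformers import AutoConfig, AutoModelForVision2Seq

# Old target chain: architecture config only, weights are randomly initialised.
config = AutoConfig.from_pretrained(
    "Salesforce/blip-image-captioning-base", cache_dir="models/huggingface"
)
scratch_model = AutoModelForVision2Seq.from_config(config)

# New target: the pretrained captioning weights are downloaded and loaded for fine-tuning.
pretrained_model = AutoModelForVision2Seq.from_pretrained(
    "Salesforce/blip-image-captioning-base", cache_dir="models/huggingface"
)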
29 changes: 29 additions & 0 deletions configs/model/blip_base_pretrained.yaml
@@ -0,0 +1,29 @@
_target_: src.models.imcap_module.IMCAPLitModule

optimizer:
_target_: torch.optim.AdamW
_partial_: true
lr: 0.00001
weight_decay: 0.05

scheduler:
_target_: torch.optim.lr_scheduler.CosineAnnealingLR
_partial_: true
T_max: ${trainer.max_epochs}
eta_min: 0.000001

net:
_target_: transformers.AutoModelForVision2Seq.from_pretrained
pretrained_model_name_or_path: Salesforce/blip-image-captioning-base
cache_dir: models/huggingface

processor:
_target_: transformers.AutoProcessor.from_pretrained
pretrained_model_name_or_path: Salesforce/blip-image-captioning-base
cache_dir: models/huggingface

# compile model for faster training with pytorch 2.0
compile: false

# HuggingFace repo ID to push model
hf_repo_id: null
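
In this new model config, the optimizer and scheduler entries carry _partial_: true, so Hydra hands the LightningModule callables rather than fully constructed objects; the module supplies its parameters later (typically inside configure_optimizers). A minimal sketch of that mechanism, using a stand-in parameter list rather than the real model:

import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

opt_cfg = OmegaConf.create({
    "_target_": "torch.optim.AdamW",
    "_partial_": True,
    "lr": 1e-5,
    "weight_decay": 0.05,
})
optimizer_factory = instantiate(opt_cfg)           # functools.partial(AdamW, lr=1e-05, weight_decay=0.05)
params = [torch.nn.Parameter(torch.zeros(2, 2))]   # stand-in for the LitModule's parameters
optimizer = optimizer_factory(params=params)       # what configure_optimizers would do with the partial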
File renamed without changes.
4 changes: 2 additions & 2 deletions configs/train.yaml
@@ -5,7 +5,7 @@
defaults:
- _self_
- data: flickr30k
- model: blip_base
- model: blip_base_pretrained
- callbacks: default
- logger: null # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
- trainer: default
@@ -15,7 +15,7 @@ defaults:

# experiment configs allow for version control of specific hyperparameters
# e.g. best hyperparameters for given model and datamodule
- experiment: flickr30k_blip-base_v1-0
- experiment: flickr30k_blip-base_v2-0

# config for hyperparameter optimization
- hparams_search: null
82 changes: 19 additions & 63 deletions src/data/flickr30k_datamodule.py
@@ -1,21 +1,12 @@
import os
import torch
import polars as pl
from pathlib import Path
from lightning import LightningDataModule
from typing import Any, Dict, Optional, Tuple
from torch.utils.data import DataLoader
from datasets import Dataset, DatasetDict
from transformers import AutoProcessor
from src.data.components import SelectText, Strip, Tokenize, LoadImage
from torchvision.transforms.v2 import (
Compose,
ToImage,
ToDtype,
Resize,
CenterCrop,
Normalize,
)
from PIL import Image


class Flickr30kDataModule(LightningDataModule):
@@ -28,13 +19,10 @@ def __init__(
processor: AutoProcessor,
data_dir: str = "data/flickr30k",
use_all_comments: bool = False,
comment_number: int = None,
comment_number: int = 0,
padding: str = "max_length",
max_length: int = 128,
truncation: bool = True,
crop_size: int = 224,
image_mean: Tuple[float] = (0.48145466, 0.4578275, 0.40821073),
image_std: Tuple[float] = (0.26862954, 0.26130258, 0.27577711),
train_val_test_split: Tuple[float] = (0.8, 0.1, 0.1),
batch_size: int = 64,
num_workers: int = 0,
@@ -59,12 +47,6 @@ def __init__(
The maximum length of the sequence, by default 128.
truncation : bool, optional
Whether to truncate the sequence, by default True.
image_mean : Tuple[float], optional
The mean values for image normalization, by default (0.48145466, 0.4578275, 0.40821073).
image_std : Tuple[float], optional
The standard deviation values for image normalization, by default (0.26862954, 0.26130258, 0.27577711).
crop_size : int, optional
The size of the crop, by default 224.
train_val_test_split : Tuple[float], optional
The split ratio for the train, validation, and test sets, by default (0.8, 0.1, 0.1).
batch_size : int, optional
@@ -83,28 +65,7 @@ def __init__(
self.save_hyperparameters(logger=False)

# data transformations
self.text_transforms = Compose(
[
SelectText(index=comment_number),
Strip(),
Tokenize(
processor=processor,
max_length=max_length,
padding=padding,
truncation=truncation,
),
]
)
self.vision_transforms = Compose(
[
LoadImage(image_dir=Path(data_dir) / "flickr30k_images"),
ToImage(),
ToDtype(dtype=torch.float32),
Resize(size=crop_size),
CenterCrop(size=crop_size),
Normalize(mean=image_mean, std=image_std),
]
)
self.processor = processor

self.dataset: Optional[DatasetDict] = None
self.num_examples: int = 0
@@ -132,26 +93,21 @@ def setup(self, stage: Optional[str] = None) -> None:
The stage to setup, by default
"""
def transform(batch):
transformed_batch = {
"pixel_values": [],
"input_ids": [],
"attention_mask": [],
}

for image_name in batch["image_name"]:
outputs = self.vision_transforms(image_name)
transformed_batch["pixel_values"].append(outputs.squeeze(0))

for comment in batch["comment"]:
outputs = self.text_transforms(comment)
for key in outputs:
transformed_batch[key].append(outputs[key].squeeze(0))

for key in transformed_batch:
transformed_batch[key] = torch.stack(transformed_batch[key])
transformed_batch["labels"] = transformed_batch["input_ids"].clone()

return transformed_batch
images = [
Image.open(Path(self.hparams.data_dir) / f"flickr30k_images/{image_name}")
for image_name in batch["image_name"]
]
texts = [comment.strip() for comment in batch["comment"]]
batch = self.processor(
images=images,
text=texts,
padding=self.hparams.padding,
max_length=self.hparams.max_length,
truncation=self.hparams.truncation,
return_tensors="pt",
)
batch.update({"labels": batch["input_ids"]})
return batch

# Divide batch size by the number of devices.
if self.trainer is not None:
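
The rewritten transform above replaces the separate torchvision and tokenizer pipelines with a single call to the BLIP processor, which handles resizing, normalization, tokenization, and padding in one pass. A small usage sketch of what it produces for one example (the file name and caption below are placeholders, not taken from the dataset):

from pathlib import Path
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
image = Image.open(Path("data/flickr30k/flickr30k_images") / "example.jpg")  # placeholder file
batch = processor(
    images=[image],
    text=["a person walking a dog on the beach"],  # placeholder caption
    padding="max_length",
    max_length=64,
    truncation=True,
    return_tensors="pt",
)
batch["labels"] = batch["input_ids"]
print(batch["pixel_values"].shape)  # e.g. torch.Size([1, 3, 384, 384]) for this checkpoint
print(batch["input_ids"].shape)     # torch.Size([1, 64]) with max_length padding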
@@ -174,7 +130,7 @@ def transform(batch):
},
)
if not self.hparams.use_all_comments:
df = df.group_by("image_name", maintain_order=True).all()
df = df.filter(pl.col("comment_number") == self.hparams.comment_number)
self.num_examples = len(df)
dataset = Dataset.from_polars(df)
dataset = dataset.shuffle(seed=int(os.environ.get("PL_GLOBAL_SEED", 42)))
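
The last hunk of this file changes how a single caption per image is kept when use_all_comments is false: the old group_by(...).all() collapsed every caption for an image into one list-valued row, whereas the new filter keeps exactly the row whose comment_number matches the configured value. A toy polars illustration (column names follow the Flickr30k results.csv, values are made up):

import polars as pl

df = pl.DataFrame({
    "image_name": ["a.jpg", "a.jpg", "b.jpg", "b.jpg"],
    "comment_number": [0, 1, 0, 1],
    "comment": ["a dog", "a brown dog", "a cat", "a sleeping cat"],
})

# Old behaviour: one row per image with all captions gathered into a list column.
grouped = df.group_by("image_name", maintain_order=True).all()

# New behaviour: one flat row per image, keeping only comment_number == 0.
filtered = df.filter(pl.col("comment_number") == 0)
print(filtered)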
7 changes: 1 addition & 6 deletions tests/test_datamodules.py
@@ -36,12 +36,10 @@ def test_flickr30k_datamodule(batch_size: int, train_val_test_split: tuple) -> None:
)
data_dir = "data/flickr30k"
use_all_comments = False
comment_number = None
comment_number = 0
padding = "max_length"
max_length = 128
truncation = True
image_mean = (0.48145466, 0.4578275, 0.40821073)
image_std = (0.26862954, 0.26130258, 0.27577711)
crop_size = 384

assert Path(data_dir, "results.csv").exists(), \
@@ -57,9 +55,6 @@ def test_flickr30k_datamodule(batch_size: int, train_val_test_split: tuple) -> None:
padding=padding,
max_length=max_length,
truncation=truncation,
image_mean=image_mean,
image_std=image_std,
crop_size=crop_size,
train_val_test_split=train_val_test_split,
batch_size=batch_size,
)
