
CLEFT: Language-Image Contrastive Learning with Efficient Large Language Model and Prompt Fine-Tuning

Yuexi Du, Brian Chang, Nicha C. Dvornek
Yale University

This is the official implementation of the paper "CLEFT: Language-Image Contrastive Learning with Efficient Large Language Model and Prompt Fine-Tuning" (Accepted by MICCAI 2024).

Abstract

Recent advancements in Contrastive Language-Image Pretraining (CLIP) [21] have demonstrated notable success in self-supervised representation learning across various tasks. However, the existing CLIP-like approaches often demand extensive GPU resources and prolonged training times due to the considerable size of the model and dataset, making them poorly suited for medical applications, in which large datasets are not always common. Meanwhile, the language model prompts are mainly manually derived from labels tied to images, potentially overlooking the richness of information within training samples. We introduce a novel language-image Contrastive Learning method with an Efficient large language model and prompt Fine-Tuning (CLEFT) that harnesses the strengths of the extensive pre-trained language and visual models. Furthermore, we present an efficient strategy for learning context-based prompts that mitigates the gap between informative clinical diagnostic data and simple class labels. Our method demonstrates state-of-the-art performance on multiple chest X-ray and mammography datasets compared with various baselines. The proposed parameter-efficient framework can reduce the total trainable model size by 39% and reduce the trainable language model to only 4% compared with the current BERT encoder.
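CLEFT builds on the CLIP-style contrastive objective [21]. As a rough illustration only (this is not the CLEFT code, and CLEFT's exact loss may differ), the standard symmetric contrastive loss over a batch of matched image/text embeddings can be sketched in pure Python. The fixed logit_scale of 1/0.07 is an assumption borrowed from CLIP's initial temperature; in practice this scale is usually learned.

```python
import math
from typing import List, Sequence


def _normalize(v: Sequence[float]) -> List[float]:
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def _cross_entropy(logits: Sequence[float], target: int) -> float:
    """Numerically stable cross-entropy (-log softmax) for one row of logits."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def clip_contrastive_loss(
    image_embs: Sequence[Sequence[float]],
    text_embs: Sequence[Sequence[float]],
    logit_scale: float = 1 / 0.07,  # assumed fixed here; CLIP-style models usually learn it
) -> float:
    """Symmetric InfoNCE loss: pair i is the positive, all other batch items are negatives."""
    imgs = [_normalize(v) for v in image_embs]
    txts = [_normalize(v) for v in text_embs]
    n = len(imgs)
    # logits[i][j] = scaled cosine similarity between image i and text j
    logits = [[logit_scale * sum(a * b for a, b in zip(img, txt)) for txt in txts] for img in imgs]
    # Image-to-text: each row should select its own caption.
    loss_i2t = sum(_cross_entropy(logits[i], i) for i in range(n)) / n
    # Text-to-image: each column should select its own image.
    cols = [[logits[i][j] for i in range(n)] for j in range(n)]
    loss_t2i = sum(_cross_entropy(cols[j], j) for j in range(n)) / n
    return (loss_i2t + loss_t2i) / 2
```

When all similarities are equal the loss is log(N) for a batch of N pairs, and it approaches zero as each image aligns only with its own caption, which is why larger batches (more negatives) matter for this objective.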

Environment

Install the dependencies with the following steps:

conda env create -f environment.yml
# Manually install cosine annealing with warmup
pip install 'git+https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup'

If you encounter errors, you may need to install the flash-attn and xformers packages manually:

# Manually install flash attention
pip install flash-attn --no-build-isolation
# Manually install xformers
pip3 install -U xformers --index-url https://download.pytorch.org/whl/cu118

Our model can run without these two packages, but training and inference may be slower and use more GPU memory.

Dataset

Our experiments mainly use the following three datasets. Please follow the paper to pre-process the images. In general, we resize each image so that its shorter side is 518 pixels and rename it with the suffix _resized.
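The resizing rule can be sketched as below. This is a hypothetical helper, not part of the repository: placing _resized before the file extension and using bicubic interpolation are both assumptions, so follow the paper for the exact pre-processing.

```python
from pathlib import Path
from typing import Tuple

SHORT_SIDE = 518  # target length of the shorter image side, from this README


def resized_dims(width: int, height: int, short_side: int = SHORT_SIDE) -> Tuple[int, int]:
    """Return (width, height) scaled so the shorter side equals short_side, keeping aspect ratio."""
    if width <= height:
        return short_side, round(height * short_side / width)
    return round(width * short_side / height), short_side


def resized_path(path: str, suffix: str = "_resized") -> str:
    """Hypothetical naming scheme: insert the suffix before the file extension."""
    p = Path(path)
    return str(p.with_name(p.stem + suffix + p.suffix))


def resize_file(path: str) -> str:
    """Resize one image with Pillow and save it next to the original (assumes Pillow is installed)."""
    from PIL import Image  # imported lazily so the helpers above work without Pillow

    out = resized_path(path)
    with Image.open(path) as img:
        # img.size is (width, height); bicubic resampling is an assumed choice
        img.resize(resized_dims(*img.size), Image.BICUBIC).save(out)
    return out
```

For example, a 2048 x 1024 image would become 1036 x 518.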

CheXpert-1.0

Download here.

RSNA Dataset

Download here.

EMBED Dataset

Request access here.

Data Split

We provide the data split for the CheXpert dataset here. We use the same data split for the RSNA dataset.

Note that we cannot share the data split for the EMBED dataset publicly, as access to this dataset requires approval. Please contact the authors once you have access to the EMBED dataset, and we will share the data split with you.

Reproduce the Experiment Results

We use wandb (Weights & Biases) to log our experiment results, so configure wandb (e.g., by running wandb login) before reproducing the results.

To reproduce the results in the paper, you may follow the steps below:

Contrastive Pre-training

First, we do contrastive pre-training by running the following command:

python train.py  --batch_size 72 --learning_rate 4e-5 --experiment_name lora_linear_proj_learn_scale_pool_img_aug_swdcy --devices 4 --strategy 'ddp_find_unused_parameters_true' --llm_type gpt --precision bf16-true --peft lora --accumulate_grad_batches 1 --grad_ckpt --weight_decay 0.1 --warm_up 4000 --emb_dim 512 --max_steps 40000 --linear_proj --pool_feat
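Assuming --batch_size is the per-GPU batch size, as is usual for PyTorch Lightning DDP training (the README does not state this), the number of samples per optimizer step in the command above works out as follows:

```python
def effective_batch_size(batch_size: int, devices: int, accumulate_grad_batches: int) -> int:
    """Samples per optimizer step, assuming batch_size is per device (Lightning DDP convention)."""
    return batch_size * devices * accumulate_grad_batches


# Values from the contrastive pre-training command above.
print(effective_batch_size(batch_size=72, devices=4, accumulate_grad_batches=1))  # 288
```

If you have fewer GPUs, raising --accumulate_grad_batches restores the same number of samples per step, but for an in-batch contrastive loss it does not restore the same number of negatives per loss computation, so results may still differ.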

You may use a different PEFT method by changing the --peft parameter. Note that we use full BFloat16 precision during training, which is only supported by NVIDIA GPUs with the Ampere architecture or newer. You may use Float16 on older GPUs, but this may result in different behavior.
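A minimal sketch for choosing a precision setting from the GPU's CUDA compute capability (Ampere corresponds to 8.0). It assumes train.py forwards --precision unchanged to the PyTorch Lightning Trainer; the "16-mixed" fallback is the Lightning 2.x FP16 mixed-precision setting and is an assumption, since the README only shows bf16-true.

```python
from typing import Tuple


def pick_precision(capability: Tuple[int, int]) -> str:
    """Map a CUDA compute capability (major, minor) to a Lightning precision string.

    BFloat16 needs Ampere (8.0) or newer; older GPUs fall back to FP16 mixed precision.
    """
    major, _minor = capability
    return "bf16-true" if major >= 8 else "16-mixed"


def detect_precision() -> str:
    """Query the current GPU (assumes PyTorch with CUDA support is installed)."""
    import torch

    return pick_precision(torch.cuda.get_device_capability())
```

PyTorch's torch.cuda.is_bf16_supported() is an alternative check for BFloat16 support.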

Prompt Fine-tuning

Then run the command below to perform context prompt fine-tuning, replacing <path_to_pretrained_ckpt> with the checkpoint from the contrastive pre-training step.

python train.py  --batch_size 72 --learning_rate 1e-4  --experiment_name prompt_tuning_ft_vit_slr --devices 4 --strategy 'ddp_find_unused_parameters_true' --llm_type gpt --precision bf16-true --accumulate_grad_batches 1 --ctx_init caption --peft lora --max_steps 8000 --weight_decay 1e-3 --warm_up 100 --emb_dim 512 --linear_proj --pool_feat --pretrained_encoder <path_to_pretrained_ckpt> --grad_ckpt --min_lr 1e-5 --data_pct 1.0 --freeze_llm --sgd --prompt_ft --ctx_length 30
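The --learning_rate, --min_lr, --warm_up, and --max_steps flags suggest a warmup-then-cosine learning-rate schedule, consistent with the cosine-annealing-with-warmup package installed above. The sketch below is a generic version of such a schedule, not the repository's code: the warmup start value (min_lr) and the absence of restarts are assumptions.

```python
import math


def lr_at(step: int, max_lr: float, min_lr: float, warmup_steps: int, max_steps: int) -> float:
    """Linear warmup from min_lr to max_lr, then cosine decay back to min_lr at max_steps."""
    if step < warmup_steps:
        return min_lr + (max_lr - min_lr) * step / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    progress = min(1.0, (step - warmup_steps) / max(1, max_steps - warmup_steps))
    return min_lr + (max_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))


# Prompt fine-tuning values: --learning_rate 1e-4 --min_lr 1e-5 --warm_up 100 --max_steps 8000
for s in (0, 50, 100, 4050, 8000):
    print(s, f"{lr_at(s, 1e-4, 1e-5, 100, 8000):.2e}")
```

The learning rate peaks at 1e-4 after 100 steps, falls to the midpoint of the range halfway through the decay, and returns to 1e-5 at step 8000.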

Model Evaluation

To evaluate the models, run

# CheXpert-5x200
python train.py  --batch_size 72 --learning_rate 4e-5 --experiment_name lora_linear_proj_learn_scale_pool_img_aug_swdcy --devices 4 --strategy 'ddp_find_unused_parameters_true' --llm_type gpt --precision bf16-true --peft lora --accumulate_grad_batches 1 --grad_ckpt --weight_decay 0.1 --warm_up 4000 --emb_dim 512 --max_steps 40000 --linear_proj --pool_feat --eval --five_cls --pretrained_model <path_to_pretrained_ckpt> 
# RSNA
python train.py  --batch_size 72 --learning_rate 4e-5 --experiment_name lora_linear_proj_learn_scale_pool_img_aug_swdcy --devices 4 --strategy 'ddp_find_unused_parameters_true' --llm_type gpt --precision bf16-true --peft lora --accumulate_grad_batches 1 --grad_ckpt --weight_decay 0.1 --warm_up 4000 --emb_dim 512 --max_steps 40000 --linear_proj --pool_feat --eval --five_cls --pretrained_model <path_to_pretrained_ckpt> --rsna 
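The README does not say which protocol --eval runs. If it performs CLIP-style zero-shot classification, each image embedding is compared with one text-prompt embedding per class and the most similar class is predicted. The sketch below illustrates only that idea; the toy embeddings are made up, not model outputs.

```python
import math
from typing import Dict, List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def zero_shot_predict(image_emb: Sequence[float], class_embs: Dict[str, Sequence[float]]) -> str:
    """Return the class whose prompt embedding is most similar to the image embedding."""
    return max(class_embs, key=lambda name: cosine(image_emb, class_embs[name]))


def accuracy(preds: List[str], labels: List[str]) -> float:
    """Fraction of predictions that match the labels."""
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)


# Toy 3-d embeddings (hypothetical values for illustration only).
classes = {"Cardiomegaly": [0.9, 0.1, 0.0], "Edema": [0.1, 0.9, 0.2]}
print(zero_shot_predict([0.8, 0.2, 0.1], classes))  # Cardiomegaly
```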

Pre-trained Checkpoints

Links to the LoRA pre-trained models used in the paper are provided below:

| Model | Link |
|---|---|
| Contrastive Pre-train | Google Drive |
| Prompt Fine-tune | Google Drive |

The contrastive pre-trained model is generally more stable, while the prompt fine-tuned model may perform better on some tasks. Please refer to the original paper for details.

Reference

@article{du2024cleft,
  title={CLEFT: Language-Image Contrastive Learning with Efficient Large Language Model and Prompt Fine-Tuning},
  author={Du, Yuexi and Chang, Brian and Dvornek, Nicha C},
  journal={arXiv preprint arXiv:2407.21011},
  year={2024}
}