Opensource REVEAL (https://arxiv.org/abs/2212.05221)
PiperOrigin-RevId: 560033039

Scenic Authors committed Aug 25, 2023
1 parent 3806d06 · commit 6447413

Showing 43 changed files with 8,409 additions and 0 deletions.
@@ -0,0 +1,36 @@
# Repository for REVEAL: Retrieval-Augmented Visual-Language Pre-Training with Multi-Source Multimodal Knowledge Memory

![REVEAL is an End-to-End Retrieval-Augmented VLM](data/vivit.png)

### [Project Page](https://reveal-cvpr.github.io/) | [arXiv](https://arxiv.org/abs/2212.05221)

## What is REVEAL?

We propose an end-to-end Retrieval-Augmented Visual Language Model (REVEAL) that learns to encode world knowledge into a large-scale memory and to retrieve from it to answer knowledge-intensive queries.

REVEAL consists of four key components: the memory, the encoder, the retriever, and the generator. The large-scale memory encodes various sources of multimodal world knowledge (e.g., image-text pairs, question-answering pairs, and knowledge-graph triplets) via a unified encoder. The retriever finds the most relevant knowledge entries in the memory, and the generator fuses the retrieved knowledge with the input query to produce the output. A key novelty in our approach is that the memory, encoder, retriever, and generator are all pre-trained end-to-end on a massive amount of data. Furthermore, our approach can use a diverse set of multimodal knowledge sources, which is shown to yield significant gains. We show that REVEAL achieves state-of-the-art results on visual question answering and image captioning.
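As a rough illustration of this retrieve-then-fuse pipeline (a toy sketch only: the encoders, shapes, and the `retrieve_and_fuse` helper below are placeholders, not the released implementation), a retriever can score precomputed memory keys, keep the top-k entries, and hand the fused sequence to a generator:

```
# Toy sketch of retrieval-augmented generation in JAX; the memory, encoder,
# and generator here are random placeholders, not REVEAL's actual modules.
import jax
import jax.numpy as jnp


def retrieve_and_fuse(query_emb, memory_keys, memory_values, k=5,
                      temperature=0.2):
  """Scores the memory, keeps the top-k entries, and fuses them with the query.

  query_emb:     [d]       query embedding from a unified encoder.
  memory_keys:   [n, d]    precomputed key embeddings of knowledge entries.
  memory_values: [n, m, d] compressed value tokens of each entry.
  """
  scores = memory_keys @ query_emb                  # [n] relevance scores.
  top_scores, top_idx = jax.lax.top_k(scores, k)    # Top-k memory entries.
  weights = jax.nn.softmax(top_scores / temperature)
  retrieved = memory_values[top_idx]                # [k, m, d]
  # Weight the retrieved tokens and append them to the query token; a real
  # generator (e.g. a T5 decoder) would attend over this fused sequence.
  weighted = (weights[:, None, None] * retrieved).reshape(-1, query_emb.shape[0])
  return jnp.concatenate([query_emb[None, :], weighted], axis=0)  # [1 + k*m, d]


# Toy usage with random embeddings.
rng = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(rng, 3)
query = jax.random.normal(k1, (512,))
keys = jax.random.normal(k2, (100, 512))
values = jax.random.normal(k3, (100, 32, 512))
fused = retrieve_and_fuse(query, keys, values)      # shape (161, 512)
```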
More details can be found in the [paper](https://arxiv.org/abs/2212.05221), published at CVPR 2023 (Highlight).

## Citation

If you use REVEAL, please use the following BibTeX entry.

```
@inproceedings{reveal,
  title={{REVEAL:} Retrieval-Augmented Visual-Language Pre-Training with Multi-Source Multimodal Knowledge Memory},
  author={Ziniu Hu and
          Ahmet Iscen and
          Chen Sun and
          Zirui Wang and
          Kai{-}Wei Chang and
          Yizhou Sun and
          Cordelia Schmid and
          David A. Ross and
          Alireza Fathi},
  booktitle={CVPR},
  year={2023}
}
```
127 changes: 127 additions & 0 deletions
scenic/projects/knowledge_visual_language/configs/finetune_okvqa_base.py
@@ -0,0 +1,127 @@
r"""WIT Retrieval + Captioning Pre-Training.""" | ||

import ml_collections

TRAIN_DATA_SIZE = 10000


def get_config() -> ml_collections.ConfigDict:
  """Returns the base experiment configuration."""
  config = ml_collections.ConfigDict()
  config.experiment_name = 'image_caption_debug'

  config.optimizer = 'adafactor'
  n_device = 128
  batch_size = 6 * 2 * n_device
  config.optimizer_configs = ml_collections.ConfigDict()
  config.optimizer_configs.momentum = None
  # config.optimizer_configs.momentum = 0.9
  # config.optimizer_configs.dtype_momentum = 'bfloat16'
  config.optimizer_configs.weight_decay_rate = 0
  config.optimizer_configs.clipping_threshold = 10.0
  config.optimizer_configs.skip_scale_and_bias_regularization = False

  config.frozen_patterns = []
  config.not_frozen_patterns = [
      ('value_perceiver/.*', 0.1),
      # ('text_encoder/.*', 0.05),
      # ('img_encoder/.*', 0.05),
      # ('shared_token_embedder/.*', 0.02),
      ('query_head/.*', 0.3),
      ('out_decoder/.*', 1.0),
      ('key_head/.*', 0.3),
      ('head_out/.*', 0.2),
      ('fusion_encoder/.*', 0.5),
      ('att_transform/.*', 0.3),
      ('dataset_gate/.*', 0.5),
  ]

  config.grad_clip_configs = ml_collections.ConfigDict()
  config.grad_clip_configs.clip_method = 'clip_by_global_norm'
  config.grad_clip_configs.clip_value = 1.0

  config.kb_dataset_names = ['wit_table', 'cc12m_table', 'vqa_table']
  config.kb_dataset_configs = [{}, {}, {}]

  config.batch_size = batch_size
  config.eval_batch_size = batch_size
  config.rng_seed = 0
  config.update_num = True
  config.num_training_epochs = 1000
  config.data_dtype_str = 'bfloat16'
  # Model
  config.model_name = 'knowledge_fid'
  config.model = ml_collections.ConfigDict()
  config.model.image_model = 'vit'
  config.model.t5_name = 't5_1_1_base'
  # ['t5_1_1_small', 't5_1_1_base', 't5_1_1_large', 't5_1_1_xl', 't5_1_1_xxl']
  config.model.num_fusion_layers = 6
  config.model.n_compressed_tokens = 32
  config.model.key_dim = 512
  config.model.dropout_rate = 0.1
  config.model.temperature = 0.2
  config.model.retr_k = 50
  config.model.retr_data_ratio = 0.1
  config.model.label_smoothing = 1e-2
  config.model.vit_name = 'B/16'
  config.model.vit_model_path = 'JFT3b-B/16'
  # [JFT3b-B/32, JFT3b-B/16, JFT3b-L/16, JFT3b-g/14, JFT3b-G/14]
  config.model.t5_frozen_base = True
  config.model.vit_num_frozen_layers = 1 / 2
  config.model.retrieve_local = False
  config.model.use_psudo_retr = True
  config.model.disentangle = False
  config.model.gap = False
  config.model.retrieval_ratio = 0.2
  config.model.n_knowledge_source = len(config.kb_dataset_names)
  config.model.qa = True
  config.frozen_memory = False

  config.vocab_size = 32120
  config.autoregressive_decoding = ml_collections.ConfigDict()
  config.autoregressive_decoding.num_decodes = 1
  config.autoregressive_decoding.beam_search = False

  # Dataset.
  config.dataset_name = 'okvqa'
  config.dataset_configs = ml_collections.ConfigDict()

  # Learning rate.
  config.num_train_examples = TRAIN_DATA_SIZE
  steps_per_epoch = TRAIN_DATA_SIZE // config.batch_size
  config.lr_configs = ml_collections.ConfigDict()
  config.lr_configs.total_steps = int(
      config.num_training_epochs * steps_per_epoch
  )
  config.lr_configs.learning_rate_schedule = 'compound'
  config.lr_configs.factors = 'constant * rsqrt_decay * linear_warmup'
  config.lr_configs.warmup_steps = 2000
  config.lr_configs.timescale = 5000
  # config.lr_configs.steps_per_cycle = config.lr_configs.total_steps
  config.lr_configs.base_learning_rate = 1e-5
  config.lr_configs.end_learning_rate = 1e-6

  # Logging.
  config.log_summary_steps = 100
  config.log_eval_steps = 500
  config.checkpoint_steps = 1000
  config.write_summary = True
  config.xprof = True  # Profile using xprof
  config.checkpoint = True  # Do checkpointing.
  config.debug_train = False  # Debug mode during training.
  config.debug_eval = False  # Debug mode during eval.

  # Initialization configs
  config.init_from = ml_collections.ConfigDict()
  # Initializing from a vidcap model.
  config.init_from.xm = None
  config.init_from.xm = (54461417, 1)  # compress=32
  # config.init_from.xm = (51839031, 1)  # compress=32
  # config.init_from.xm = (46499437, 1)  # compress=64
  # config.init_from.resume = (50154314, 1)  # compress=32
  config.init_from.only_params = False
  config.init_from.load_key_encoder = False
  config.init_from.encoder = ml_collections.ConfigDict()
  config.init_from.encoder.init_from_vit = False
  config.init_from.encoder.checkpoint_path = None
  return config
119 changes: 119 additions & 0 deletions
scenic/projects/knowledge_visual_language/configs/wit_memory_G.py
@@ -0,0 +1,119 @@
r"""WIT Retrieval + Captioning Pre-Training.""" | ||
|
||
import ml_collections | ||
|
||
TRAIN_DATA_SIZE = 1_000_000_000 | ||
|
||
|
||
def get_config() -> ml_collections.ConfigDict: | ||
"""Returns the base experiment configuration.""" | ||
config = ml_collections.ConfigDict() | ||
config.experiment_name = 'image_caption_debug' | ||
|
||
config.optimizer = 'adafactor' | ||
n_device = 256 | ||
batch_size = 4 * 2 * n_device | ||
config.optimizer_configs = ml_collections.ConfigDict() | ||
config.optimizer_configs.momentum = None | ||
# config.optimizer_configs.momentum = 0.9 | ||
# config.optimizer_configs.dtype_momentum = 'bfloat16' | ||
config.optimizer_configs.weight_decay_rate = 2e-3 | ||
config.optimizer_configs.clipping_threshold = 5.0 | ||
config.optimizer_configs.skip_scale_and_bias_regularization = True | ||
|
||
config.frozen_patterns = [] | ||
config.not_frozen_patterns = [('value_perceiver/.*', 0.3), | ||
('text_encoder/.*', 0.1), | ||
('img_encoder/.*', 0.1), | ||
('shared_token_embedder/.*', 0.1), | ||
('query_head/.*', 0.2), ('out_decoder/.*', 1), | ||
('key_head/.*', 0.2), ('head_out/.*', 0.2), | ||
('fusion_encoder/.*', 0.5), | ||
('att_transform/.*', 0.3), | ||
('dataset_gate/.*', 0.5)] | ||
|
||
config.grad_clip_configs = ml_collections.ConfigDict() | ||
config.grad_clip_configs.clip_method = 'clip_by_global_norm' | ||
config.grad_clip_configs.clip_value = 1.0 | ||
|
||
config.kb_dataset_names = ['wit_table', 'cc12m_table', 'vqa_table'] | ||
config.kb_dataset_configs = [{}, {}, {}] | ||
|
||
config.batch_size = batch_size | ||
config.eval_batch_size = batch_size | ||
config.rng_seed = 0 | ||
config.update_num = False | ||
config.num_training_epochs = 2 | ||
config.data_dtype_str = 'bfloat16' | ||
# Model | ||
config.model_name = 'knowledge_fid' | ||
config.model = ml_collections.ConfigDict() | ||
config.model.image_model = 'vit' | ||
config.model.t5_name = 't5_1_1_large' | ||
# ['t5_1_1_small', 't5_1_1_base', 't5_1_1_large', 't5_1_1_xl', 't5_1_1_xxl'] | ||
config.model.num_fusion_layers = 8 | ||
config.model.n_compressed_tokens = 32 | ||
config.model.key_dim = 512 | ||
config.model.dropout_rate = 0.0 | ||
config.model.temperature = 0.2 | ||
config.model.retr_k = 10 | ||
config.model.retr_data_ratio = 0.2 | ||
config.model.label_smoothing = 1e-2 | ||
config.model.vit_name = 'G/14' | ||
config.model.vit_model_path = 'JFT3b-G/14' | ||
# [JFT3b-B/32, JFT3b-B/16, JFT3b-L/16, JFT3b-g/14, JFT3b-G/14] | ||
config.model.t5_frozen_base = False | ||
config.model.vit_num_frozen_layers = 5 / 6 | ||
config.model.retrieve_local = False | ||
config.model.use_psudo_retr = True | ||
config.model.disentangle = True | ||
config.model.gap = True | ||
config.model.retrieval_ratio = 1e-2 | ||
config.model.n_knowledge_source = len(config.kb_dataset_names) | ||
config.model.qa = False | ||
config.frozen_memory = False | ||
|
||
config.vocab_size = 32120 | ||
config.autoregressive_decoding = ml_collections.ConfigDict() | ||
config.autoregressive_decoding.num_decodes = 1 | ||
config.autoregressive_decoding.beam_search = False | ||
# Dataset. | ||
config.dataset_name = 'web_image_text_generation' | ||
config.dataset_configs = ml_collections.ConfigDict() | ||
|
||
# Learning rate. | ||
config.num_train_examples = TRAIN_DATA_SIZE | ||
steps_per_epoch = TRAIN_DATA_SIZE // config.batch_size | ||
config.lr_configs = ml_collections.ConfigDict() | ||
config.lr_configs.total_steps = int(config.num_training_epochs * | ||
steps_per_epoch) | ||
config.lr_configs.learning_rate_schedule = 'compound' | ||
config.lr_configs.factors = 'constant * rsqrt_decay * linear_warmup' | ||
config.lr_configs.warmup_steps = 20000 | ||
config.lr_configs.timescale = 10000 | ||
# config.lr_configs.steps_per_cycle = config.lr_configs.total_steps | ||
config.lr_configs.base_learning_rate = 1e-3 | ||
config.lr_configs.end_learning_rate = 1e-6 | ||
|
||
# Logging. | ||
config.log_summary_steps = 100 | ||
config.log_eval_steps = 1000 | ||
config.checkpoint_steps = 5000 | ||
config.write_summary = True | ||
config.xprof = True # Profile using xprof | ||
config.checkpoint = True # Do checkpointing. | ||
config.debug_train = False # Debug mode during training. | ||
config.debug_eval = False # Debug mode during eval. | ||
|
||
  # Initialization configs
  config.init_from = ml_collections.ConfigDict()
  # Initializing from a vidcap model.
  # config.init_from.xm = None
  # config.init_from.xm = (46234383, 1)  # compress=32
  config.init_from.resume = (49645684, 1)  # compress=64
  config.init_from.only_params = False
  config.init_from.load_key_encoder = False
  config.init_from.encoder = ml_collections.ConfigDict()
  config.init_from.encoder.init_from_vit = False
  config.init_from.encoder.checkpoint_path = None
  return config
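Both configs request the `'compound'` schedule with factors `'constant * rsqrt_decay * linear_warmup'`. As a rough sketch of how such factors are commonly composed (following the mesh-TensorFlow-style convention; Scenic's exact lr_schedules implementation may differ), using this file's `warmup_steps` and `timescale`:

```
# Rough sketch of a 'constant * rsqrt_decay * linear_warmup' compound schedule,
# with the values from this config; not necessarily Scenic's exact formula.
import math


def compound_lr(step, base_lr=1e-3, warmup_steps=20000, timescale=10000):
  constant = base_lr
  linear_warmup = min(1.0, step / warmup_steps)
  rsqrt_decay = 1.0 / math.sqrt(max(step, timescale))
  return constant * linear_warmup * rsqrt_decay


for step in (1000, 20000, 100000):
  print(step, compound_lr(step))
```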