This is the official codebase of the paper
A Systematic Study of Joint Representation Learning on Protein Sequences and Structures (arXiv:2303.06275)
Zuobai Zhang, Chuanrui Wang*, Minghao Xu*, Vijil Chenthamarakshan, Aurelie Lozano, Payel Das, Jian Tang
To explore the benefits of combining sequence- and structure-based protein encoders, we conduct a comprehensive investigation into joint protein representation learning. Our study combines a state-of-the-art protein language model (PLM), ESM-2, with three distinct structure encoders (GVP, GearNet, and CDConv). We introduce three fusion strategies (serial, parallel, and cross fusion) to combine sequence and structure representations.
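As we read the fusion strategies, serial fusion feeds the PLM's per-residue embeddings into the structure encoder as its node features, while parallel fusion runs the two encoders independently and concatenates their outputs. The minimal sketch below shows only that data flow for these two strategies; cross fusion is not sketched. `toy_sequence_encoder` (standing in for ESM-2) and `toy_structure_encoder` (one round of neighbour averaging over an edge list, standing in for GVP/GearNet/CDConv) are hypothetical toys and do not reflect the real models.

```python
from typing import Dict, List, Sequence, Tuple

Vector = List[float]
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"  # the 20 standard residue types


def one_hot(residue: str) -> Vector:
    """Residue-type features, the usual input of a structure-only encoder."""
    return [1.0 if aa == residue else 0.0 for aa in AMINO_ACIDS]


def toy_sequence_encoder(sequence: str) -> List[Vector]:
    """Hypothetical stand-in for ESM-2: one-hot type plus a relative-position feature."""
    n = len(sequence)
    return [one_hot(aa) + [i / max(n - 1, 1)] for i, aa in enumerate(sequence)]


def toy_structure_encoder(node_features: List[Vector],
                          edges: Sequence[Tuple[int, int]]) -> List[Vector]:
    """Hypothetical stand-in for a structure encoder: average each residue
    with its graph neighbours (one round of message passing)."""
    neighbours: Dict[int, List[int]] = {i: [i] for i in range(len(node_features))}
    for u, v in edges:
        neighbours[u].append(v)
        neighbours[v].append(u)
    dim = len(node_features[0])
    return [
        [sum(node_features[j][d] for j in neighbours[i]) / len(neighbours[i])
         for d in range(dim)]
        for i in range(len(node_features))
    ]


def serial_fusion(sequence: str, edges: Sequence[Tuple[int, int]]) -> List[Vector]:
    """Serial: PLM residue embeddings become the structure encoder's node features."""
    return toy_structure_encoder(toy_sequence_encoder(sequence), edges)


def parallel_fusion(sequence: str, edges: Sequence[Tuple[int, int]]) -> List[Vector]:
    """Parallel: run both encoders independently, then concatenate per residue."""
    seq_repr = toy_sequence_encoder(sequence)
    struct_repr = toy_structure_encoder([one_hot(aa) for aa in sequence], edges)
    return [s + t for s, t in zip(seq_repr, struct_repr)]


if __name__ == "__main__":
    seq, edges = "MKT", [(0, 1), (1, 2)]
    print(len(serial_fusion(seq, edges)[0]))    # 21 features per residue
    print(len(parallel_fusion(seq, edges)[0]))  # 41 = 21 (sequence) + 20 (structure)
```

The structural difference is visible in the output widths: serial fusion keeps the PLM's feature size because the structure encoder consumes it, whereas parallel fusion widens the representation by concatenation.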
We further explore six pre-training techniques (Residue Type Prediction, Distance Prediction, Angle Prediction, Dihedral Prediction, Multiview Contrast, and SiamDiff), applying them to the best-performing model from the combinations above and pre-training on the AlphaFold Database.
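Three of these tasks predict geometric quantities of the structure. The sketch below shows how such targets can be computed from residue coordinates. It assumes each residue is a single 3D point (for example its alpha carbon) and that the points are passed in directly. To our understanding, the actual tasks sample residue pairs, triples and quadruples from the protein graph; the prediction heads and losses are not shown. All helper names are hypothetical.

```python
import math
from typing import List, Sequence

Point = Sequence[float]


def _sub(a: Point, b: Point) -> List[float]:
    return [x - y for x, y in zip(a, b)]


def _dot(a: Point, b: Point) -> float:
    return sum(x * y for x, y in zip(a, b))


def _cross(a: Point, b: Point) -> List[float]:
    return [a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0]]


def _norm(a: Point) -> float:
    return math.sqrt(_dot(a, a))


def distance(p: Point, q: Point) -> float:
    """Euclidean distance between two residues (Distance Prediction target)."""
    return _norm(_sub(p, q))


def angle(p0: Point, p1: Point, p2: Point) -> float:
    """Angle in radians at p1 between p1->p0 and p1->p2 (Angle Prediction target)."""
    a, b = _sub(p0, p1), _sub(p2, p1)
    cos = _dot(a, b) / (_norm(a) * _norm(b))
    return math.acos(max(-1.0, min(1.0, cos)))  # clamp against rounding error


def dihedral(p0: Point, p1: Point, p2: Point, p3: Point) -> float:
    """Signed torsion angle in radians about the p1-p2 axis (Dihedral Prediction target)."""
    b0 = _sub(p0, p1)
    b1 = _sub(p2, p1)
    b2 = _sub(p3, p2)
    length = _norm(b1)
    b1 = [x / length for x in b1]
    # Project b0 and b2 onto the plane perpendicular to the central bond.
    proj0, proj2 = _dot(b0, b1), _dot(b2, b1)
    v = [x - proj0 * y for x, y in zip(b0, b1)]
    w = [x - proj2 * y for x, y in zip(b2, b1)]
    return math.atan2(_dot(_cross(b1, v), w), _dot(v, w))


if __name__ == "__main__":
    print(distance((0, 0, 0), (3, 4, 0)))                        # 5.0
    print(angle((1, 0, 0), (0, 0, 0), (0, 1, 0)))                # pi / 2, about 1.5708
    print(dihedral((1, 0, 0), (0, 0, 0), (0, 0, 1), (1, 0, 1)))  # 0.0 (cis arrangement)
```

Using `atan2` rather than `acos` for the dihedral keeps its sign, so the value covers the full (-pi, pi] range instead of only [0, pi].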
You can find the pre-trained model weights here, including ESM-GearNet pre-trained with Multiview Contrast, Residue Type Prediction, Distance Prediction, Angle Prediction, Dihedral Prediction and SiamDiff.
You may install the dependencies via either conda or pip. Generally, ESM-GearNet works with Python 3.7/3.8 and PyTorch version >= 1.12.0.
Using conda:
conda install torchdrug pytorch=1.12.1 cudatoolkit=11.6 -c milagraph -c pytorch -c pyg -c conda-forge
conda install easydict pyyaml -c conda-forge
conda install transformers==4.14.1 tokenizers==0.10.3 -c huggingface
pip install atom3d
Using pip:
pip install torch==1.12.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
pip install torchdrug
pip install easydict pyyaml
pip install atom3d
pip install transformers==4.14.1 tokenizers==0.10.3
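After installation, a quick check against the stated requirements (Python 3.7/3.8, PyTorch >= 1.12.0) can save a failed run. The helper below is a hypothetical sketch, not part of the repository. It uses only the standard library, imports PyTorch lazily, and strips local version suffixes such as `+cu116` before comparing.

```python
import re
import sys
from typing import Dict, Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Leading numeric part of a version string: '1.12.1+cu116' -> (1, 12, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError("unrecognized version string: %r" % version)
    return tuple(int(part) for part in match.group(0).split("."))


def meets_minimum(version: str, minimum: str) -> bool:
    """Compare versions numerically, padding with zeros so '1.12' == '1.12.0'."""
    have, need = parse_version(version), parse_version(minimum)
    width = max(len(have), len(need))
    have += (0,) * (width - len(have))
    need += (0,) * (width - len(need))
    return have >= need


def check_environment() -> Dict[str, object]:
    """Report whether Python and PyTorch match the stated requirements."""
    report: Dict[str, object] = {
        "python": "%d.%d" % sys.version_info[:2],
        "python_ok": sys.version_info[:2] in {(3, 7), (3, 8)},
    }
    try:
        import torch  # imported lazily so the check also runs without PyTorch
    except ImportError:
        report.update(torch=None, torch_ok=False)
    else:
        report.update(torch=str(torch.__version__),
                      torch_ok=meets_minimum(str(torch.__version__), "1.12.0"),
                      cuda=torch.cuda.is_available())
    return report


if __name__ == "__main__":
    print(check_environment())
```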
To reproduce the results of ESM-{GVP, GearNet, CDConv}, use the following commands. Alternatively, you may change the `gpus` parameter in the config files to switch to other GPUs. All datasets are downloaded automatically by the code.
The first run takes longer because the datasets must be preprocessed.
# Run ESM-GearNet (serial fusion) on the Enzyme Commission dataset with 4 GPUs
python -m torch.distributed.launch --nproc_per_node=4 script/downstream.py -c config/EC/esm_gearnet.yaml
# ESM-GearNet (parallel fusion)
python -m torch.distributed.launch --nproc_per_node=4 script/downstream.py -c config/EC/esm_gearnet_parallel.yaml
# ESM-GearNet (cross fusion)
python -m torch.distributed.launch --nproc_per_node=4 script/downstream.py -c config/EC/esm_gearnet_cross.yaml
# Run ESM-GearNet (serial fusion) on the Gene Ontology dataset (Molecular Function branch)
python -m torch.distributed.launch --nproc_per_node=4 script/downstream.py -c config/GO/esm_gearnet.yaml --branch MF
# Run ESM-GearNet (serial fusion) on the PSR (Protein Structure Ranking) dataset
python -m torch.distributed.launch --nproc_per_node=4 script/downstream.py -c config/PSR/esm_gearnet.yaml
# Run ESM-GearNet (serial fusion) on the MSP (Mutation Stability Prediction) dataset
python -m torch.distributed.launch --nproc_per_node=4 script/downstream.py -c config/MSP/esm_gearnet.yaml
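If you script these runs, a small helper that assembles the same launch command keeps `--nproc_per_node` and the config path in one place. This is a hypothetical sketch. We assume the process count should match the number of GPUs listed in the config's `gpus` parameter, so change both together.

```python
import shlex
import subprocess
from typing import List, Optional, Sequence


def downstream_command(config: str, num_gpus: int = 4,
                       extra_args: Optional[Sequence[str]] = None) -> List[str]:
    """Argument list equivalent to the shell commands above."""
    if num_gpus < 1:
        raise ValueError("num_gpus must be at least 1")
    cmd = ["python", "-m", "torch.distributed.launch",
           "--nproc_per_node=%d" % num_gpus,
           "script/downstream.py", "-c", config]
    cmd.extend(extra_args or [])
    return cmd


def launch(cmd: List[str]) -> None:
    """Run the command from the repository root; raises if the run fails."""
    subprocess.run(cmd, check=True)


if __name__ == "__main__":
    go_cmd = downstream_command("config/GO/esm_gearnet.yaml", extra_args=["--branch", "MF"])
    # Print a copy-pasteable shell command; call launch(go_cmd) to actually run it.
    print(" ".join(shlex.quote(part) for part in go_cmd))
```

Passing an argument list to `subprocess.run` avoids shell quoting issues; `shlex.quote` is used only to print a copy-pasteable command.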
By default, pre-training uses the AlphaFold Database. To pre-train ESM-GearNet with Multiview Contrast, use the following command. As with the downstream tasks, the dataset is downloaded automatically and preprocessed the first time you run the code.
# Run pre-training
python -m torch.distributed.launch --nproc_per_node=4 script/pretrain.py -c config/pretrain/mc_esm_gearnet.yaml
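Multiview Contrast trains the encoder so that two views of the same protein map to similar representations while views of different proteins are pushed apart. The sketch below computes an InfoNCE-style loss of that kind. It assumes `views_a[i]` and `views_b[i]` are embeddings of two views of protein i, with the other proteins in the batch acting as negatives. The temperature of 0.07 is only an illustrative default, and the view construction (in GearNet, cropping plus random edge masking) is not reproduced here.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def cosine(u: Vector, v: Vector) -> float:
    """Cosine similarity between two embeddings."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def info_nce(views_a: List[Vector], views_b: List[Vector], temperature: float = 0.07) -> float:
    """Mean InfoNCE loss: views_a[i] should match views_b[i] among all views_b[j]."""
    if len(views_a) != len(views_b) or not views_a:
        raise ValueError("need two equally sized, non-empty batches of views")
    total = 0.0
    for i, anchor in enumerate(views_a):
        logits = [cosine(anchor, other) / temperature for other in views_b]
        top = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = top + math.log(sum(math.exp(z - top) for z in logits))
        total += log_sum_exp - logits[i]  # cross-entropy with target class i
    return total / len(views_a)


if __name__ == "__main__":
    a = [[1.0, 0.0], [0.0, 1.0]]
    print(info_nce(a, a))                         # near 0: each positive pair is most similar
    print(info_nce(a, [[0.0, 1.0], [1.0, 0.0]]))  # large: positives are mismatched
```

The loss is a softmax cross-entropy over similarities, so a lower temperature sharpens the distinction between the positive pair and the in-batch negatives.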
After pre-training, you can load the model weights from the saved checkpoint via the `--ckpt` argument and then fine-tune the model on downstream tasks.
Remember to first uncomment the `model_checkpoint: {{ ckpt }}` line in the config file.
python -m torch.distributed.launch --nproc_per_node=4 script/downstream.py -c config/EC/esm_gearnet.yaml --ckpt <path_to_your_model>
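The config files appear to be templates in which `{{ ckpt }}` is filled in from the `--ckpt` argument. If the `model_checkpoint` line stays commented out, the rendered path ends up inside a YAML comment and is ignored. The sketch below mimics both steps on a made-up config excerpt. `enable_checkpoint` and `render` are hypothetical helpers, and the simple regex substitution only stands in for the script's own template rendering.

```python
import re

# Matches a commented-out `model_checkpoint: {{ ckpt }}` line, keeping its indentation.
_COMMENTED_CKPT = re.compile(
    r"^([ \t]*)#[ \t]*(model_checkpoint:[ \t]*\{\{[ \t]*ckpt[ \t]*\}\})", re.MULTILINE)
# Matches any {{ name }} placeholder.
_PLACEHOLDER = re.compile(r"\{\{\s*(\w+)\s*\}\}")


def enable_checkpoint(config_text: str) -> str:
    """Uncomment the model_checkpoint line so the rendered path reaches the YAML."""
    return _COMMENTED_CKPT.sub(r"\1\2", config_text)


def render(config_text: str, **variables: str) -> str:
    """Stand-in for template rendering: replace each {{ name }} with its value."""
    def substitute(match):
        name = match.group(1)
        if name not in variables:
            raise KeyError("no value given for template variable %r" % name)
        return variables[name]
    return _PLACEHOLDER.sub(substitute, config_text)


if __name__ == "__main__":
    config = "task:\n  # model_checkpoint: {{ ckpt }}\n"  # made-up config excerpt
    print(render(enable_checkpoint(config), ckpt="/path/to/model.pth"))
```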
If you find this codebase useful in your research, please cite the following paper.
@article{zhang2023enhancing,
title={A Systematic Study of Joint Representation Learning on Protein Sequences and Structures},
author={Zhang, Zuobai and Wang, Chuanrui and Xu, Minghao and Chenthamarakshan, Vijil and Lozano, Aurelie and Das, Payel and Tang, Jian},
journal={arXiv preprint arXiv:2303.06275},
year={2023}
}