commit 96e937c (0 parents: initial commit)
Showing 19 changed files with 1,427 additions and 0 deletions.

.flake8
@@ -0,0 +1,10 @@
[flake8]
max-line-length = 99
ignore = E203,W503
exclude =
    .git,
    __pycache__,
    build,
    dist,
    experimental
    third_party

.gitignore
@@ -0,0 +1,27 @@
# Compiler Output #
###################
*.py[cod]
*.so
*.o
*.exe
*.class

# Folders #
###########
bin/
build/
dist/
local/
tmp/
__pycache__/
*.egg-info/
.ipynb_checkpoints/
.vscode/

# Junk #
########
.DS_Store*
.*.swp
*.swp
*.log
*~

CODE_OF_CONDUCT.rst
@@ -0,0 +1,6 @@
Code of Conduct
===============

Facebook has adopted a Code of Conduct that we expect project participants to adhere to. Please `read the full text`__ so that you can understand what actions will and will not be tolerated.

__ https://code.facebook.com/codeofconduct

CONTRIBUTING.md
@@ -0,0 +1,31 @@
# Contributing to esm
We want to make contributing to this project as easy and transparent as
possible.

## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `master`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints (a sketch of the local checks follows this file).
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## License
By contributing to esm, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
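
The lint step above uses the flake8 configuration from the .flake8 file in this commit; here is a sketch of running the local checks (pytest is an assumption, since this commit does not show a test runner):

.. code-block:: bash

    pip install flake8 pytest
    flake8    # picks up the .flake8 settings at the repository root
    pytest    # assumed test runner; not specified in this commit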

LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) Facebook, Inc. and its affiliates.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.rst
@@ -0,0 +1,118 @@
===========================================================================
Evolutionary Scale Modeling (esm): Pretrained language models for proteins
===========================================================================

This repository contains a PyTorch implementation of the transformer protein language models in
`"Biological structure and function emerge from scaling unsupervised learning
to 250 million protein sequences" (Rives et al., 2019)`__
from Facebook AI Research, along with pre-trained models.

__ https://doi.org/10.1101/622803

Quickstart
==========

As a prerequisite, you must have PyTorch 1.5 or later installed to use this repository.
A CUDA device is optional and will be auto-detected.

Use this one-liner for installation:

.. code-block:: bash

    $ pip install git+https://github.com/facebookresearch/esm.git

Then, you can load and use a pretrained model as follows:
.. code-block:: python

    import torch
    import esm

    # Load 34 layer model
    model, alphabet = esm.pretrained.esm1_t34_670M_UR50S()
    batch_converter = alphabet.get_batch_converter()

    # Prepare data (two protein sequences)
    data = [("protein1", "MYLYQKIKN"), ("protein2", "MNAKYD")]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)

    # Extract per-residue representations (on CPU)
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[34])
    token_representations = results["representations"][34]

    # Generate per-sequence representations via averaging
    # NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
    sequence_representations = []
    for i, (_, seq) in enumerate(data):
        sequence_representations.append(token_representations[i, 1:len(seq) + 1].mean(0))
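
For orientation, these are the shapes produced by the snippet above (a sketch; ``hidden_dim`` stands in for the model's embedding width, which this README does not state):

.. code-block:: python

    # batch_tokens has one row per sequence and 1 + max sequence length columns:
    # a beginning-of-sequence token, then the residues, right-padded for shorter sequences.
    print(batch_tokens.shape)                  # torch.Size([2, 10]); len("MYLYQKIKN") == 9
    print(token_representations.shape)         # (2, 10, hidden_dim)
    print(sequence_representations[0].shape)   # (hidden_dim,)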
We also support PyTorch Hub, which removes the need to clone and/or install this repository yourself:

.. code-block:: python

    import torch
    model, alphabet = torch.hub.load("facebookresearch/esm", "esm1_t34_670M_UR50S")

FASTA representation extractor
------------------------------

For your convenience, we have provided a script that efficiently extracts representations in bulk from a FASTA file:

.. code-block:: bash

    # Extract final-layer representations for a FASTA file from a 34-layer model
    $ python extract.py esm1_t34_670M_UR50S examples/some_proteins.fasta my_reprs/ \
        --repr_layers 0 32 34 --include-per-tok --include-mean

    # my_reprs/ now contains one ".pt" file per FASTA sequence; use torch.load() to load them.
    # extract.py has flags that determine what's included in the ".pt" file:
    # --repr_layers (default: final only) selects which layers to include representations from.
    # --include-per-tok includes the full sequence, with an embedding per amino acid (seq_len x hidden_dim).
    # --include-mean includes the embeddings per layer, averaged over the full sequence.
    # --include-bos includes the embeddings from the beginning-of-sequence token.
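
To inspect what was saved, load one of the output files back into Python (a sketch; the file name is illustrative, and the dictionary keys depend on which flags were passed to extract.py):

.. code-block:: python

    import torch

    reprs = torch.load("my_reprs/protein1.pt")  # hypothetical output file name
    print(reprs.keys())  # see which representations the chosen flags produced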
Available models
================

The following pretrained models are available for use.
The names correspond to Table 1 in the updated paper and encode the
number of layers, the number of parameters, and the training dataset;
a sketch of loading them follows the list.

* esm1_t34_670M_UR50S -- this is the best model and should be the go-to choice.
* esm1_t34_670M_UR50D
* esm1_t34_670M_UR100
* esm1_t12_85M_UR50S
* esm1_t6_43M_UR50S
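
Only ``esm1_t34_670M_UR50S`` appears in the quickstart above; it is assumed here that every listed name follows the same ``esm.pretrained`` loader convention:

.. code-block:: python

    import esm

    # Assumption: each model name doubles as a loader function in esm.pretrained.
    # The smallest model is chosen here to keep the download quick.
    model, alphabet = esm.pretrained.esm1_t6_43M_UR50S()
    model.eval()  # disable dropout for deterministic representations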

Reference
=========

If you find the model useful in your research, we ask that you cite the
following paper:

.. code-block:: bibtex

    @article{rives2019biological,
        author={Rives, Alexander and Meier, Joshua and Sercu, Tom and Goyal, Siddharth and Lin, Zeming and Guo, Demi and Ott, Myle and Zitnick, C. Lawrence and Ma, Jerry and Fergus, Rob},
        title={Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences},
        year={2019},
        doi={10.1101/622803},
        url={https://www.biorxiv.org/content/10.1101/622803v3},
        journal={bioRxiv}
    }

Additionally, much of this code hails from the excellent `fairseq`__ sequence modeling framework; we have released this standalone model to facilitate more lightweight and flexible usage. We encourage those who wish to pretrain protein language models from scratch to use fairseq.

__ https://github.com/pytorch/fairseq
License
=======

This source code is licensed under the MIT license found in the ``LICENSE`` file
in the root directory of this source tree.

esm/__init__.py
@@ -0,0 +1,10 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .version import version as __version__  # noqa

from .data import Alphabet, BatchConverter, FastaBatchedDataset  # noqa
from .model import ProteinBertModel  # noqa
from . import pretrained  # noqa
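
The re-exports above define the package's public surface; a quick sanity-check sketch (assuming the package installs as described in the README):

.. code-block:: python

    import esm

    # Everything imported in esm/__init__.py is reachable at the top level.
    model, alphabet = esm.pretrained.esm1_t34_670M_UR50S()
    print(type(model))  # expected: esm.ProteinBertModel (assumption)
    print(isinstance(alphabet.get_batch_converter(), esm.BatchConverter))  # True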

esm/constants.py
@@ -0,0 +1,8 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

proteinseq_toks = {
    'toks': ['L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O']
}
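
This token dictionary is shaped to plug straight into ``Alphabet.from_dict`` from ``esm/data.py`` below; a sketch of that round trip (how the model code actually consumes it is not shown in this commit):

.. code-block:: python

    from esm.constants import proteinseq_toks
    from esm.data import Alphabet

    alphabet = Alphabet.from_dict(proteinseq_toks)  # 25 standard tokens plus specials
    print(len(alphabet), alphabet.padding_idx, alphabet.cls_idx)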

esm/data.py
@@ -0,0 +1,139 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

class FastaBatchedDataset(object):
    def __init__(self, sequence_labels, sequence_strs):
        self.sequence_labels = list(sequence_labels)
        self.sequence_strs = list(sequence_strs)

    @classmethod
    def from_file(cls, fasta_file):
        sequence_labels, sequence_strs = [], []
        cur_seq_label = None
        buf = []

        def _flush_current_seq():
            nonlocal cur_seq_label, buf
            if cur_seq_label is None:
                return
            sequence_labels.append(cur_seq_label)
            sequence_strs.append("".join(buf))
            cur_seq_label = None
            buf = []

        with open(fasta_file, "r") as infile:
            for line_idx, line in enumerate(infile):
                if line.startswith(">"):  # label line
                    _flush_current_seq()
                    line = line[1:].strip()
                    if len(line) > 0:
                        cur_seq_label = line
                    else:
                        # Unlabeled record: synthesize a label from the line number.
                        cur_seq_label = f"seqnum{line_idx:09d}"
                else:  # sequence line
                    buf.append(line.strip())

        _flush_current_seq()

        # Labels must be unique across the file.
        assert len(set(sequence_labels)) == len(sequence_labels)

        return cls(sequence_labels, sequence_strs)

    def __len__(self):
        return len(self.sequence_labels)

    def __getitem__(self, idx):
        return self.sequence_labels[idx], self.sequence_strs[idx]

    def get_batch_indices(self, toks_per_batch, extra_toks_per_seq=0):
        # Greedy length-bucketed batching: sort sequences by length, then pack
        # consecutive sequences into a batch until the padded token count
        # (batch size x longest sequence in the batch) would exceed the budget.
        sizes = [(len(s), i) for i, s in enumerate(self.sequence_strs)]
        sizes.sort()
        batches = []
        buf = []
        max_len = 0

        def _flush_current_buf():
            nonlocal max_len, buf
            if len(buf) == 0:
                return
            batches.append(buf)
            buf = []
            max_len = 0

        for sz, i in sizes:
            sz += extra_toks_per_seq
            if max(sz, max_len) * (len(buf) + 1) > toks_per_batch:
                _flush_current_buf()
            max_len = max(max_len, sz)
            buf.append(i)

        _flush_current_buf()
        return batches

class Alphabet(object):
    def __init__(self, standard_toks):
        self.standard_toks = list(standard_toks)

        # Special tokens come first; the vocabulary is then padded with <null_*>
        # entries to a multiple of 8 (presumably for hardware-friendly embedding
        # sizes) before the trailing specials are appended.
        self.all_toks = ["<null_0>", "<pad>", "<eos>", "<unk>"]
        self.all_toks += self.standard_toks
        for i in range((8 - (len(self.all_toks) % 8)) % 8):
            self.all_toks.append(f"<null_{i + 1}>")
        self.all_toks += ["<cls>", "<mask>", "<sep>"]

        self.tok_to_idx = {tok: i for i, tok in enumerate(self.all_toks)}

        self.padding_idx = self.get_idx("<pad>")
        self.cls_idx = self.get_idx("<cls>")
        self.mask_idx = self.get_idx("<mask>")
        self.sep_idx = self.get_idx("<sep>")

    def __len__(self):
        return len(self.all_toks)

    def get_idx(self, tok):
        return self.tok_to_idx[tok]

    def get_tok(self, ind):
        return self.all_toks[ind]

    def to_dict(self):
        # Round-trips with from_dict below (the original `self.toks` did not exist).
        return {"toks": self.standard_toks}

    def get_batch_converter(self):
        return BatchConverter(self)

    @classmethod
    def from_dict(cls, d):
        return cls(standard_toks=d["toks"])

class BatchConverter(object):
    """Callable to convert an unprocessed (labels + strings) batch to a
    processed (labels + tensor) batch.
    """

    def __init__(self, alphabet):
        self.alphabet = alphabet

    def __call__(self, raw_batch):
        batch_size = len(raw_batch)
        max_len = max(len(seq_str) for _, seq_str in raw_batch)
        # One extra column for the beginning-of-sequence (<cls>) token;
        # shorter sequences are right-padded with <pad>.
        tokens = torch.empty((batch_size, max_len + 1), dtype=torch.int64)
        tokens.fill_(self.alphabet.padding_idx)
        labels = []
        strs = []

        for i, (label, seq_str) in enumerate(raw_batch):
            labels.append(label)
            strs.append(seq_str)
            tokens[i, 0] = self.alphabet.cls_idx
            seq = torch.tensor([self.alphabet.get_idx(s) for s in seq_str], dtype=torch.int64)
            tokens[i, 1 : len(seq_str) + 1] = seq

        return labels, strs, tokens
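
Taken together, these classes support bulk inference under a token budget; a sketch of how they compose (the FASTA path, the budget, and the wiring are illustrative, since this commit's extract.py is not shown here):

.. code-block:: python

    import torch
    import esm
    from esm.data import FastaBatchedDataset

    model, alphabet = esm.pretrained.esm1_t34_670M_UR50S()

    dataset = FastaBatchedDataset.from_file("examples/some_proteins.fasta")
    # Reserve one token per sequence for the <cls> prepended by BatchConverter.
    batches = dataset.get_batch_indices(toks_per_batch=1024, extra_toks_per_seq=1)

    # get_batch_indices yields lists of dataset indices, which is exactly what
    # DataLoader's batch_sampler expects; BatchConverter serves as the collate_fn.
    loader = torch.utils.data.DataLoader(
        dataset, collate_fn=alphabet.get_batch_converter(), batch_sampler=batches
    )
    with torch.no_grad():
        for labels, strs, toks in loader:
            results = model(toks, repr_layers=[34])
            token_reprs = results["representations"][34]  # (batch, max_len + 1, hidden_dim)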