# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import sys
import subprocess
import tempfile
import requests
import shutil
from pathlib import Path
import torch
import esm


def test_readme_1():
    import torch

    model, alphabet = torch.hub.load("facebookresearch/esm:main", "esm1b_t33_650M_UR50S")
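

# Illustrative aside (not from the README): torch.hub caches the downloaded weights;
# if the default cache location is unsuitable, it can be redirected before loading.
# The directory below is only an example.
def _example_load_with_custom_cache_dir():
    import torch

    torch.hub.set_dir("/tmp/torch_hub_cache")  # hypothetical cache directory
    return torch.hub.load("facebookresearch/esm:main", "esm1b_t33_650M_UR50S")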


def test_readme_2():
    import torch
    import esm

    # Load ESM-1b model
    model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
    batch_converter = alphabet.get_batch_converter()
    model.eval()  # disables dropout for deterministic results

    # Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
    data = [
        ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
        ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
        (
            "protein2 with mask",
            "KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE",
        ),
        ("protein3", "K A <mask> I S Q"),
    ]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)

    # Extract per-residue representations (on CPU)
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33], return_contacts=True)
    token_representations = results["representations"][33]

    # Generate per-sequence representations via averaging
    # NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
    sequence_representations = []
    for i, (_, seq) in enumerate(data):
        sequence_representations.append(token_representations[i, 1 : len(seq) + 1].mean(0))

    # Look at the unsupervised self-attention map contact predictions
    import matplotlib.pyplot as plt

    for (_, seq), attention_contacts in zip(data, results["contacts"]):
        plt.matshow(attention_contacts[: len(seq), : len(seq)])
        plt.title(seq)
        plt.show()
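

# Illustrative sketch (not part of the README snippet above): `data` includes
# sequences with a "<mask>" token, and the forward pass also returns per-position
# logits. Assuming `results["logits"]` has shape (batch, tokens, vocab), the model's
# amino-acid distribution at each masked position can be read off like this; the
# `token_idx - 1` arithmetic matches the BOS-token NOTE in test_readme_2.
def _example_masked_token_probabilities(alphabet, batch_tokens, results):
    import torch

    # Locate every <mask> token in the batch: rows of (batch_idx, token_idx)
    mask_positions = (batch_tokens == alphabet.mask_idx).nonzero(as_tuple=False)
    for batch_idx, token_idx in mask_positions.tolist():
        probs = torch.softmax(results["logits"][batch_idx, token_idx], dim=-1)
        top_prob, top_tok = probs.max(dim=-1)
        print(
            f"sequence {batch_idx}, residue {token_idx - 1}: "
            f"{alphabet.get_tok(top_tok.item())} ({top_prob.item():.3f})"
        )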


def _run_py_cmd(cmd, **kwargs):
    # Run the command with the same interpreter that is running the tests, so the
    # subprocess sees the same environment and installed packages.
    this_python = sys.executable
    cmd = cmd.replace("python", this_python)  # str.replace returns a new string; reassign it
    subprocess.run(cmd, shell=True, check=True, **kwargs)


def test_readme_3():
    # NOTE modifications from the README command, for test speed:
    # * some_proteins -> few_proteins (a subset)
    # * reference values were computed against esm1 rather than esm1b, hence layer 34 as well as 33
    cmd = """
python scripts/extract.py esm1_t34_670M_UR50S examples/data/few_proteins.fasta examples/data/few_proteins_emb_esm1/ \
    --repr_layers 0 33 34 --include mean per_tok
"""
    _run_py_cmd(cmd)
    confirm_all_tensors_equal(
        "examples/data/few_proteins_emb_esm1/",
        "https://dl.fbaipublicfiles.com/fair-esm/tests/some_proteins_emb_esm1_t34_670M_UR50S_ref",
    )
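

# Illustrative sketch of what extract.py writes (the dict structure is inferred from
# assert_pt_file_equal below; the file name is hypothetical): each .pt file maps
# layer numbers to per-token representation tensors.
def _example_load_embedding_file():
    d = torch.load("examples/data/few_proteins_emb_esm1/protein1.pt")  # hypothetical file
    per_tok = d["representations"][34]  # (seq_len, embed_dim) tensor for layer 34
    return per_tok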


def assert_pt_file_equal(f, fref):
    a = torch.load(f)
    b = torch.load(fref)
    # set intersection of dict keys: compare only layers present in both dumps
    which_layers = a["representations"].keys() & b["representations"].keys()
    assert which_layers, "Expected at least one layer appearing in both dumps"
    for layer in which_layers:
        assert torch.allclose(a["representations"][layer], b["representations"][layer], atol=1e-3)


def confirm_all_tensors_equal(local_dir: str, ref_dir: str) -> None:
    # TODO use pytest built-in fixtures for tmp_path https://docs.pytest.org/en/6.2.x/fixture.html#fixtures
    for fn in Path(local_dir).glob("*.pt"):
        with tempfile.NamedTemporaryFile(mode="w+b", prefix=fn.name) as f:
            ref_url = f"{ref_dir}/{fn.name}"
            with requests.get(ref_url, stream=True) as r:
                r.raise_for_status()  # fail loudly on a bad download rather than comparing garbage
                shutil.copyfileobj(r.raw, f)
            f.seek(0)
            assert_pt_file_equal(fn, f)
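

# Sketch of the TODO above (illustrative, not what this file currently uses): with
# pytest's built-in tmp_path fixture, the tempfile bookkeeping disappears. tmp_path
# is a per-test pathlib.Path that pytest creates and cleans up automatically.
def _example_confirm_with_tmp_path(tmp_path, local_dir: str, ref_dir: str) -> None:
    for fn in Path(local_dir).glob("*.pt"):
        ref_file = tmp_path / fn.name
        r = requests.get(f"{ref_dir}/{fn.name}")
        r.raise_for_status()
        ref_file.write_bytes(r.content)  # reference files are small, so buffering in memory is fine
        assert_pt_file_equal(fn, ref_file)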


def test_msa_transformers():
    _test_msa_transformer(*esm.pretrained.esm_msa1_t12_100M_UR50S())
    _test_msa_transformer(*esm.pretrained.esm_msa1b_t12_100M_UR50S())


def _test_msa_transformer(model, alphabet):
    batch_converter = alphabet.get_batch_converter()
    # Make an "MSA" of size 3
    data = [
        ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
        ("protein2", "MHTVRQSRLKSIVRILEMSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
        ("protein3", "MHTVRQSRLKSIVRILEMSKEPVSGAQL---LSVSRQVIVQDIAYLRSLGYNIVAT----VLAGG"),
    ]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[12], return_contacts=True)
    token_representations = results["representations"][12]
    # Shape: (num_msas, num_seqs, seq_len + 1 BOS token, embed_dim)
    assert token_representations.shape == (1, 3, 66, 768)
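

# Illustrative sketch, based on the assumption that the MSA batch converter also
# accepts a list of MSAs (each a list of (label, sequence) tuples) and pads them
# into one batch; with two MSAs the leading dimension would be 2 rather than 1.
def _example_batch_two_msas(alphabet, msa_a, msa_b):
    batch_converter = alphabet.get_batch_converter()
    labels, strs, tokens = batch_converter([msa_a, msa_b])
    # expected shape: (num_msas=2, max_num_seqs, max_seq_len + 1 BOS token)
    return tokens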


def test_variant_readme_1():
    cmd = """
python predict.py \
    --model-location esm1v_t33_650M_UR90S_1 esm1v_t33_650M_UR90S_2 esm1v_t33_650M_UR90S_3 esm1v_t33_650M_UR90S_4 esm1v_t33_650M_UR90S_5 \
    --sequence HPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW \
    --dms-input ./data/BLAT_ECOLX_Ranganathan2015.csv \
    --mutation-col mutant \
    --dms-output ./data/BLAT_ECOLX_Ranganathan2015_labeled.csv \
    --offset-idx 24 \
    --scoring-strategy wt-marginals
"""
    _run_py_cmd(cmd, cwd="examples/variant-prediction/")


def test_variant_readme_2():
    cmd = """
python predict.py \
    --model-location esm_msa1b_t12_100M_UR50S \
    --sequence HPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW \
    --dms-input ./data/BLAT_ECOLX_Ranganathan2015.csv \
    --mutation-col mutant \
    --dms-output ./data/BLAT_ECOLX_Ranganathan2015_labeled.csv \
    --offset-idx 24 \
    --scoring-strategy masked-marginals \
    --msa-path ./data/BLAT_ECOLX_1_b0.5.a3m
"""
    _run_py_cmd(cmd, cwd="examples/variant-prediction/")
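

# Illustrative follow-up to the variant-prediction tests: a quick look at the labeled
# CSV they write. That the scores land in one column per model name is an assumption
# about predict.py's output layout, as is the availability of pandas here.
def _example_inspect_dms_scores():
    import pandas as pd  # assumed to be installed in the test environment

    df = pd.read_csv("examples/variant-prediction/data/BLAT_ECOLX_Ranganathan2015_labeled.csv")
    print(df.head())  # mutant column plus one score column per model (assumed layout)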


if __name__ == "__main__":
    confirm_all_tensors_equal(
        "examples/data/few_proteins_emb_esm1/",
        "https://dl.fbaipublicfiles.com/fair-esm/tests/some_proteins_emb_esm1_t34_670M_UR50S_ref",
    )