test_load_all.py (forked from facebookresearch/esm)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import pytest
from pathlib import Path
import torch
import esm

# Directly from hubconf.py
model_names = """
esm1_t6_43M_UR50S,
esm1_t12_85M_UR50S,
esm1_t34_670M_UR50S,
esm1_t34_670M_UR50D,
esm1_t34_670M_UR100,
esm1b_t33_650M_UR50S,
esm_msa1_t12_100M_UR50S,
esm_msa1b_t12_100M_UR50S,
esm1v_t33_650M_UR90S,
esm1v_t33_650M_UR90S_1,
esm1v_t33_650M_UR90S_2,
esm1v_t33_650M_UR90S_3,
esm1v_t33_650M_UR90S_4,
esm1v_t33_650M_UR90S_5,
esm_if1_gvp4_t16_142M_UR50,
"""
model_names = [mn.strip() for mn in model_names.strip(" ,\n").split(",")]


@pytest.mark.parametrize("model_name", model_names)
def test_load_hub_fwd_model(model_name: str) -> None:
    model, alphabet = getattr(esm.pretrained, model_name)()
    # batch_size = 2, seq_len = 3, tokens within vocab
    dummy_inp = torch.tensor([[0, 1, 2], [3, 4, 5]])
    if "esm_msa" in model_name:
        # MSA models expect an extra alignment-depth dimension: (batch, num_seqs, seq_len)
        dummy_inp = dummy_inp.unsqueeze(0)
    output = model(dummy_inp)  # dict
    # squeeze(0) undoes the extra MSA dimension; it is a no-op for the other models
    logits = output["logits"].squeeze(0)
    assert logits.shape == (2, 3, len(alphabet))


@pytest.mark.parametrize("model_name", model_names)
def test_load_local(model_name: str) -> None:
    # Assumes the hub tests above have already downloaded & cached every checkpoint.
    local_path = Path.home() / ".cache/torch/hub/checkpoints" / (model_name + ".pt")
    if model_name.endswith("esm1v_t33_650M_UR90S"):
        return  # skip; the generic name is rerouted to a specific instance, so no checkpoint exists under this exact filename
    model, alphabet = esm.pretrained.load_model_and_alphabet_local(local_path)
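

# Illustrative sketch (not part of the original test suite): how a checkpoint
# exercised by the tests above is typically used on real sequences via the
# alphabet's batch converter. The model choice (esm1b_t33_650M_UR50S), the
# sequences, and repr_layers=[33] are assumptions made only for this example.
def example_forward_pass() -> None:
    model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
    model.eval()
    batch_converter = alphabet.get_batch_converter()
    data = [
        ("protein1", "MKTVRQERLKSIVRILERSKEPVSG"),
        ("protein2", "KALTARQQEVFDLIRDHISQTGMPP"),
    ]
    _, _, tokens = batch_converter(data)  # (labels, strings, token tensor)
    with torch.no_grad():
        out = model(tokens, repr_layers=[33])
    # Per-residue embeddings for layer 33: (batch, tokens incl. BOS/EOS, hidden_dim)
    representations = out["representations"][33]
    assert representations.shape[:2] == tokens.shape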