# transfer_model.py -- forked from Kuhlman-Lab/ThermoMPNN
import os

import torch
import torch.nn as nn

from protein_mpnn_utils import ProteinMPNN, tied_featurize
from model_utils import featurize

HIDDEN_DIM = 128
EMBED_DIM = 128
VOCAB_DIM = 21
ALPHABET = 'ACDEFGHIKLMNPQRSTVWYX'

MLP = True
SUBTRACT_MUT = True


def get_protein_mpnn(cfg, version='v_48_020.pt'):
    """Load a pre-trained ProteinMPNN model to provide structure embeddings."""
    hidden_dim = 128
    num_layers = 3

    model_weight_dir = os.path.join(cfg.platform.thermompnn_dir, 'vanilla_model_weights')
    checkpoint_path = os.path.join(model_weight_dir, version)
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model = ProteinMPNN(ca_only=False, num_letters=21, node_features=hidden_dim, edge_features=hidden_dim,
                        hidden_dim=hidden_dim, num_encoder_layers=num_layers, num_decoder_layers=num_layers,
                        k_neighbors=checkpoint['num_edges'], augment_eps=0.0)

    if cfg.model.load_pretrained:
        model.load_state_dict(checkpoint['model_state_dict'])

    if cfg.model.freeze_weights:
        model.eval()
        # freeze these weights for transfer learning
        for param in model.parameters():
            param.requires_grad = False

    return model
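
# Usage sketch (illustrative only): `cfg` is assumed to be an OmegaConf/Hydra-style
# config exposing the fields read above; the path below is a placeholder, not part
# of the repo.
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.create({
#       'platform': {'thermompnn_dir': '/path/to/ThermoMPNN'},
#       'model': {'load_pretrained': True, 'freeze_weights': True},
#   })
#   mpnn = get_protein_mpnn(cfg)  # frozen ProteinMPNN used as a structure encoder
#   assert not any(p.requires_grad for p in mpnn.parameters())
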
class TransferModel(nn.Module):
    """Transfer model: ProteinMPNN embeddings at a mutated position -> predicted ddG."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.hidden_dims = list(cfg.model.hidden_dims)
        self.subtract_mut = cfg.model.subtract_mut
        self.num_final_layers = cfg.model.num_final_layers
        self.lightattn = cfg.model.lightattn if 'lightattn' in cfg.model else False

        if 'decoding_order' not in self.cfg:
            self.cfg.decoding_order = 'left-to-right'

        self.prot_mpnn = get_protein_mpnn(cfg)

        hid_sizes = [HIDDEN_DIM * self.num_final_layers + EMBED_DIM]
        hid_sizes += self.hidden_dims
        hid_sizes += [VOCAB_DIM]

        print('MLP HIDDEN SIZES:', hid_sizes)

        if self.lightattn:
            print('Enabled LightAttention')
            self.light_attention = LightAttention(embeddings_dim=HIDDEN_DIM * self.num_final_layers + EMBED_DIM)

        self.both_out = nn.Sequential()
        for sz1, sz2 in zip(hid_sizes, hid_sizes[1:]):
            self.both_out.append(nn.ReLU())
            self.both_out.append(nn.Linear(sz1, sz2))

        self.ddg_out = nn.Linear(1, 1)

    def forward(self, pdb, mutations, tied_feat=True):
        device = next(self.parameters()).device

        (X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, visible_list_list,
         masked_list_list, masked_chain_length_list_list, chain_M_pos, omit_AA_mask, residue_idx,
         dihedral_mask, tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all,
         bias_by_res_all, tied_beta) = tied_featurize([pdb[0]], device, None, None, None, None,
                                                      None, None, ca_only=False)

        # getting ProteinMPNN structure embeddings
        all_mpnn_hid, mpnn_embed, _ = self.prot_mpnn(X, S, mask, chain_M, residue_idx, chain_encoding_all, None)
        if self.num_final_layers > 0:
            mpnn_hid = torch.cat(all_mpnn_hid[:self.num_final_layers], -1)

        out = []
        for mut in mutations:
            inputs = []
            if mut is None:
                out.append(None)
                continue

            aa_index = ALPHABET.index(mut.mutation)
            wt_aa_index = ALPHABET.index(mut.wildtype)

            if self.num_final_layers > 0:
                hid = mpnn_hid[0][mut.position]  # MPNN hidden embeddings at mut position
                inputs.append(hid)

            embed = mpnn_embed[0][mut.position]  # MPNN seq embeddings at mut position
            inputs.append(embed)

            # concatenating hidden layers and embeddings
            lin_input = torch.cat(inputs, -1)

            # passing vector through lightattn
            if self.lightattn:
                lin_input = torch.unsqueeze(torch.unsqueeze(lin_input, -1), 0)
                lin_input = self.light_attention(lin_input, mask)

            both_input = torch.unsqueeze(self.both_out(lin_input), -1)
            ddg_out = self.ddg_out(both_input)

            if self.subtract_mut:
                ddg = ddg_out[aa_index][0] - ddg_out[wt_aa_index][0]
            else:
                ddg = ddg_out[aa_index][0]

            out.append({
                "ddG": torch.unsqueeze(ddg, 0),
            })

        return out, None
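
# Usage sketch (illustrative only): `pdb_dict` stands for a parsed-structure dict of
# the kind the repo's data loaders feed to tied_featurize, and the SimpleNamespace
# mutation is a stand-in for whatever object the caller passes -- forward() only
# needs .position, .wildtype, and .mutation on it.
#
#   from types import SimpleNamespace
#   model = TransferModel(cfg)
#   muts = [SimpleNamespace(position=10, wildtype='A', mutation='G')]
#   preds, _ = model([pdb_dict], muts)
#   print(preds[0]['ddG'])  # 1-element tensor with the predicted ddG
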
class LightAttention(nn.Module):
    """Light attention pooling over per-residue embeddings.

    Source: Hannes Stark et al. 2022
    https://github.com/HannesStark/protein-localization/blob/master/models/light_attention.py
    """

    def __init__(self, embeddings_dim=1024, output_dim=11, dropout=0.25, kernel_size=9, conv_dropout: float = 0.25):
        super(LightAttention, self).__init__()

        self.feature_convolution = nn.Conv1d(embeddings_dim, embeddings_dim, kernel_size, stride=1,
                                             padding=kernel_size // 2)
        self.attention_convolution = nn.Conv1d(embeddings_dim, embeddings_dim, kernel_size, stride=1,
                                               padding=kernel_size // 2)

        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(conv_dropout)

    def forward(self, x: torch.Tensor, mask, **kwargs) -> torch.Tensor:
        """
        Args:
            x: [batch_size, embeddings_dim, sequence_length] embedding tensor to be pooled
            mask: [batch_size, sequence_length] mask corresponding to the zero padding used for the
                shorter sequences in the batch; padding positions are False, the rest are True.
                (Accepted for interface compatibility but unused here.)

        Returns:
            attention-weighted features with singleton dimensions squeezed out,
            e.g. a [embeddings_dim] tensor for a [1, embeddings_dim, 1] input
        """
        o = self.feature_convolution(x)  # [batch_size, embeddings_dim, sequence_length]
        o = self.dropout(o)  # [batch_size, embeddings_dim, sequence_length]
        attention = self.attention_convolution(x)  # [batch_size, embeddings_dim, sequence_length]
        o1 = o * self.softmax(attention)
        return torch.squeeze(o1)
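
# Shape sketch (illustrative only, dimensions chosen to match TransferModel above with
# num_final_layers=4: 128*4 + 128 = 640). The input is a single mutated position
# reshaped to [1, embeddings_dim, 1], and `mask` is ignored by this forward pass.
#
#   la = LightAttention(embeddings_dim=640)
#   x = torch.randn(1, 640, 1)
#   out = la(x, mask=None)
#   print(out.shape)  # torch.Size([640]) after the final squeeze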