-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpredict_edgenorm.py
103 lines (85 loc) · 2.91 KB
/
predict_edgenorm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import datasets
import torch.utils.data
from config import config
import models as models
import torch_utils
import torch.optim
import timm.scheduler
import torch.nn as nn
from tqdm import tqdm
import numpy as np
import global_data
from torch.utils.data.distributed import DistributedSampler
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import os
import torch.distributed as dist
import pickle
import common_utils
# Name of the training run; predict() looks for checkpoints under
# os.path.join('models_valid', run_id).
run_id = 'dropout0.01_decay1_0.97_h16_hidden512_prednorm_rateloss_valid-train_l16'
# Epoch number of the checkpoint to load; formatted as epoch_{iepoch:03d}.pt.
iepoch = 28
# Device string passed to torch.cuda.set_device(), model.to() and
# torch_utils.batch_to_device() inside predict().
device = 'cuda:2'
def _strip_key_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with ``prefix`` removed from keys that start with it.

    Keys that do not start with ``prefix`` are kept unchanged, so a checkpoint
    saved without the prefix is not corrupted by blind slicing.
    """
    cut = len(prefix)
    return {
        (key[cut:] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def predict():
    """Run the saved model over the whole dataset and collect per-sample scores.

    The checkpoint ``epoch_{iepoch:03d}.pt`` is read from
    ``models_valid/<run_id>`` and loaded into a freshly built
    ``models.MoleculePairDistPredictor``. The dataset is iterated with
    ``shuffle=False``, so the returned list follows dataset order.

    Returns:
        list: one NumPy array per sample, cropped to ``[:n, :n]`` where ``n``
        is the sum of that sample's ``atom_mask`` row (presumably the number
        of real, non-padded atoms -- confirm against ``datasets.collate_fn``).
    """
    dataset = datasets.SimplePCQM4MDataset(
        path=config['middle_data_path'], split_name='all', rotate=False)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config['batch_size'],
        num_workers=8,
        collate_fn=datasets.collate_fn,
        shuffle=False,
    )

    torch.cuda.set_device(device)
    torch.cuda.empty_cache()

    model = models.MoleculePairDistPredictor(config)
    print('num of parameters: {0}'.format(np.sum([p.numel() for p in model.parameters()])))

    model_save_path = os.path.join('models_valid', run_id)
    checkpoint_path = os.path.join(model_save_path, f'epoch_{iepoch:03d}.pt')
    sd = torch.load(checkpoint_path, map_location='cpu')
    # The original code dropped the first 7 characters of every key, which is
    # the length of the 'module.' prefix (presumably added because the model
    # was saved while wrapped in DistributedDataParallel). Strip it only when
    # it is actually present so unprefixed checkpoints load correctly too.
    sd = _strip_key_prefix(sd, 'module.')
    model.load_state_dict(sd)
    model.to(device)
    model.eval()

    scores_list = []
    # Inference only: no autograd graph is needed for any batch.
    with torch.no_grad():
        for graph, _y in tqdm(loader):
            graph = torch_utils.batch_to_device(graph, device)
            scores = model(
                graph['atom_feat_cate'],
                graph['atom_feat_float'],
                graph['atom_mask'],
                graph['bond_index'],
                graph['bond_feat_cate'],
                graph['bond_feat_float'],
                graph['bond_mask'],
                graph['structure_feat_cate'],
                graph['structure_feat_float'],
                graph['triplet_feat_cate'])
            scores = scores.detach().cpu().numpy()
            num_atom = graph['atom_mask'].sum(dim=1).detach().cpu().numpy().astype('int64')
            # Crop each sample's score matrix to its own atom count before storing.
            for i, s in enumerate(scores):
                scores_list.append(s[:num_atom[i], :num_atom[i]])
    return scores_list
def process_fn(param):
    """Store one prediction inside its sample's pickle file.

    ``param`` is a ``(prediction, index)`` pair. The sample file lives at
    ``<middle_data_path>/data2/<index // 1000, 4 digits>/<index, 7 digits>.pkl``;
    it is loaded, its graph dict receives the prediction under the key
    ``'predict_pair_dist'``, and the ``(graph, y)`` tuple is saved back to the
    same path.
    """
    prediction, index = param
    sample_file = os.path.join(
        config['middle_data_path'],
        'data2',
        f'{index // 1000:04d}',
        f'{index:07d}.pkl',
    )
    graph, target = common_utils.load_obj(sample_file)
    graph['predict_pair_dist'] = prediction
    common_utils.save_obj((graph, target), sample_file)
def main():
    """Predict for every sample, then write each prediction back in parallel.

    ``predict()`` returns one array per sample in dataset order; each array is
    paired with its position and handed to ``process_fn`` through a worker
    pool.
    """
    preds = predict()
    params = [(pred, index) for index, pred in enumerate(preds)]
    # The context manager guarantees the workers are cleaned up even when a
    # task raises; on the normal path every result has already been consumed
    # by the time the block exits, so no work is cut short.
    with mp.Pool() as pool:
        finished = pool.imap_unordered(process_fn, params, chunksize=1024)
        # Wrap the result iterator (not the inputs) so the bar reflects
        # completed writes rather than task submission.
        for _ in tqdm(finished, total=len(params)):
            pass
# Run the predict-and-write-back pipeline only when executed as a script.
if __name__ == '__main__':
    main()