-
Notifications
You must be signed in to change notification settings - Fork 1
/
check_dist_diff.py
56 lines (48 loc) · 1.51 KB
/
check_dist_diff.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import datasets2 as datasets
from config import config
from tqdm import tqdm
import numpy as np
import torch
import scipy.spatial.distance
# torch.multiprocessing.set_sharing_strategy('file_system')
def trim_padded_dists(dist, atom_mask):
    """Split a padded batch of distance matrices into per-molecule NumPy arrays.

    Args:
        dist: tensor whose leading axis is the batch; ``dist[i]`` is sliced as a
            square matrix padded to the batch's largest molecule (presumably
            shape ``(batch, max_atoms, max_atoms)`` — TODO confirm).
        atom_mask: tensor summed over ``dim=1`` to obtain each molecule's atom
            count (assumed 1 for real atoms, 0 for padding — verify against
            ``datasets.collate_fn``).

    Returns:
        List with one array per molecule, the i-th trimmed to ``(n_i, n_i)``.
        Each array is copied so it does not keep the whole padded batch buffer
        alive (a plain slice would be a view into it).

    Raises:
        ValueError: if ``dist`` and ``atom_mask`` disagree on the batch size.
    """
    num_atoms = torch.sum(atom_mask, dim=1).long().tolist()
    batch = dist.detach().cpu().numpy()
    if len(batch) != len(num_atoms):
        raise ValueError(
            f'batch size mismatch: {len(batch)} distance matrices '
            f'but {len(num_atoms)} atom masks')
    return [d[:n, :n].copy() for d, n in zip(batch, num_atoms)]


def load_dist(name, batch_size=1024, num_workers=32, device=None):
    """Load the train split *name* and return one distance matrix per molecule.

    For ``name == 'data'`` the distances are read from the batch field
    ``predict_pair_dist``; for any other name they are computed as pairwise
    Euclidean distances of ``g['xyz']`` with ``torch.cdist``.

    Args:
        name: passed as ``data_path_name`` to ``datasets.SimplePCQM4MDataset``.
        batch_size: DataLoader batch size.
        num_workers: number of DataLoader worker processes.
        device: device on which ``torch.cdist`` runs. Defaults to ``'cuda:0'``
            when CUDA is available, otherwise ``'cpu'``.

    Returns:
        List of NumPy arrays, each trimmed to ``(n_atoms, n_atoms)``, in loader
        order. The DataLoader is not asked to shuffle, so results for two
        different *name* values can be compared position by position —
        assuming both datasets list the same molecules in the same order.
    """
    if device is None:
        device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    dataset = datasets.SimplePCQM4MDataset(
        path=config['middle_data_path'], split_name='train', rotate=False,
        path_atom_map=None, data_path_name=name)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=datasets.collate_fn,
    )
    dists = []
    # Context manager closes the progress bar even if iteration raises.
    with tqdm(loader, desc=name) as bar:
        for g, _ in bar:
            if name == 'data':
                # squeeze(-1) drops a trailing size-1 axis, if present.
                dist = g['predict_pair_dist'].squeeze(-1)
            else:
                xyz = g['xyz'].to(device)
                dist = torch.cdist(xyz, xyz, p=2)
            dists.extend(trim_padded_dists(dist, g['atom_mask']))
    return dists
def mean_relative_dist_diff(d_pred, d_ref, eps=1e-12):
    """Return the mean relative error of *d_pred* against *d_ref* over off-diagonal pairs.

    Each atom pair contributes ``|d_pred - d_ref| / (d_ref + eps)``. Diagonal
    entries (an atom's distance to itself) are excluded. There the reference
    distance is zero, so any tiny nonzero prediction would be divided by *eps*
    and swamp the average.

    Args:
        d_pred: square distance matrix being evaluated.
        d_ref: square reference distance matrix of the same shape.
        eps: added to the denominator to avoid division by zero.

    Returns:
        The mean as a Python float, or ``None`` for matrices with fewer than
        two atoms (no off-diagonal pairs to compare).

    Raises:
        ValueError: if the matrices are not square or differ in shape.
    """
    d_pred = np.asarray(d_pred, dtype=np.float64)
    d_ref = np.asarray(d_ref, dtype=np.float64)
    if d_ref.ndim != 2 or d_ref.shape[0] != d_ref.shape[1]:
        raise ValueError(f'expected a square matrix, got shape {d_ref.shape}')
    if d_pred.shape != d_ref.shape:
        raise ValueError(f'shape mismatch: {d_pred.shape} vs {d_ref.shape}')
    n_atoms = d_ref.shape[0]
    if n_atoms < 2:
        return None
    off_diag = ~np.eye(n_atoms, dtype=bool)
    rel = np.abs(d_pred - d_ref)[off_diag] / (d_ref[off_diag] + eps)
    return float(np.mean(rel))


if __name__ == '__main__':
    dist_rdkit = load_dist('data')
    dist_sdf = load_dist('data2')
    # zip() would silently truncate if the two datasets differ in size.
    if len(dist_rdkit) != len(dist_sdf):
        raise ValueError(
            f"dataset sizes differ: 'data' has {len(dist_rdkit)} molecules, "
            f"'data2' has {len(dist_sdf)}")
    # Keep a running sum: recomputing np.mean over the growing list every
    # iteration is O(n^2) over the whole dataset.
    total, count = 0.0, 0
    with tqdm(zip(dist_rdkit, dist_sdf), total=len(dist_sdf)) as bar:
        for d1, d2 in bar:
            dist_diff = mean_relative_dist_diff(d1, d2)
            if dist_diff is None:  # single-atom molecule: nothing to compare
                continue
            total += dist_diff
            count += 1
            # refresh=False lets tqdm redraw at its own rate instead of every item.
            bar.set_postfix({'dist_diff': total / count}, refresh=False)
    if count:
        print(f'dist_diff: {total / count} over {count} molecules')