-
Notifications
You must be signed in to change notification settings - Fork 38
/
Copy path test_distance_based_mask.py
76 lines (60 loc) · 2.4 KB
/
test_distance_based_mask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#!/usr/bin/env python
from __future__ import print_function
import unittest
import numpy as np
import pytraj as pt
from pytraj.testing import aa_eq
from itertools import product
from utils import fn, tz2_trajin, tz2_top
class TestDistanceBasedMask(unittest.TestCase):
    """Check cpptraj-style distance-based masks (``<@``, ``>@``, ``<:``).

    Every selection is made relative to the first frame of the tz2 test
    trajectory. Results are cross-checked against saved cpptraj output,
    ``pt.search_neighbors`` and directly recomputed distances.
    """

    def test_atom_distance(self):
        """Atom-level selections ``<@5.0`` / ``>@5.0`` around ``:3@CA``."""
        traj = pt.iterload(tz2_trajin, tz2_top)
        top = traj.top
        ref = traj[0]

        # Distance masks are evaluated against the topology's reference
        # frame; use the 1st frame.
        top.set_reference(ref)
        ref.top = top

        # All atoms within 5 Angstrom of :3@CA.
        indices = top.select(":3@CA <@5.0")

        # cpptraj writes 1-based atom numbers; shift to 0-based indices.
        saved_indices = np.loadtxt(
            fn("mask.tz2.dat"), skiprows=1, usecols=(1, )) - 1
        aa_eq(indices, saved_indices)

        neighbors_smaller = pt.search_neighbors(
            traj, mask=':3@CA <@5.0', frame_indices=[0])
        aa_eq(neighbors_smaller.values, indices)

        # Re-calculate the distances and check they respect the cutoff.
        ca_indices = pt.select_atoms(':3@CA', traj.top)
        all_pairs = list(product(ca_indices, indices))
        distances = np.asarray(
            pt.tools.flatten(pt.distance(ref, all_pairs)))
        self.assertTrue(np.all(distances < 5.0),
                        'all distances must be smaller than 5.0 Angstrom')

        # NOTE(review): the original author asked why the reference frame has
        # to be set again here; it is kept so the '>@' selection is
        # guaranteed to be evaluated against the 1st frame.
        top.set_reference(ref)
        indices_larger = top.select(":3@CA >@5.0")

        # An empty selection would make the distance check below pass
        # vacuously.
        self.assertGreater(len(indices_larger), 0)
        # Strictly-closer and strictly-farther selections cannot overlap.
        self.assertEqual(np.intersect1d(indices, indices_larger).size, 0)

        all_pairs_larger = list(product(ca_indices, indices_larger))
        distances_larger = np.asarray(
            pt.tools.flatten(pt.distance(ref, all_pairs_larger)))
        self.assertTrue(np.all(distances_larger > 5.0),
                        'all distances must be larger than 5.0 Angstrom')

        # search_neighbors must agree with the topology selection.
        neighbors_larger = pt.search_neighbors(
            traj, mask=':3@CA >@5.0', frame_indices=[0])
        aa_eq(neighbors_larger.values, indices_larger)

    def test_residue_distance(self):
        """Residue-level selection ``<:5.0`` around ``:3@CA``.

        Assumes ``<:`` follows cpptraj's residue-distance mask semantics: it
        selects every atom of any residue that has at least one atom inside
        the cutoff. Individual selected atoms may therefore lie farther than
        5.0 Angstrom from the CA, so no per-atom distance bound is checked.
        Instead, this verifies containment relations that must hold.
        """
        traj = pt.iterload(tz2_trajin, tz2_top)
        top = traj.top
        ref = traj[0]
        top.set_reference(ref)
        ref.top = top

        indices_smaller = pt.select_atoms(':3@CA <:5.0', top)
        ca_indices = pt.select_atoms(':3@CA', traj.top)

        # The CA is at distance 0 from itself, so its residue (and the CA
        # atom) must be part of the selection.
        self.assertTrue(set(ca_indices) <= set(indices_smaller))

        # Any atom within 5.0 Angstrom lies in a residue within 5.0 Angstrom,
        # so the atom-level selection must be contained in the residue-level
        # one.
        top.set_reference(ref)
        atom_indices = top.select(':3@CA <@5.0')
        self.assertGreater(len(atom_indices), 0)
        self.assertTrue(set(atom_indices) <= set(indices_smaller))
if __name__ == '__main__':
    # Run this module's tests when executed directly as a script.
    unittest.main()