-
Notifications
You must be signed in to change notification settings - Fork 23
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
278 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
#!/usr/bin/env python3 | ||
# encoding: utf-8 | ||
|
||
|
||
import unittest | ||
|
||
from arc.imports import settings | ||
import arc.ts_split.split as split | ||
from arc.species.converter import check_xyz_dict, xyz_to_dmat | ||
|
||
# Server definitions read from ARC's settings.
# NOTE(review): `servers` is not referenced anywhere in the visible part of this test module - confirm it is needed.
servers = settings['servers']
|
||
|
||
class TestTSSplit(unittest.TestCase):
    """
    Unit tests for the helper functions of ARC's ``arc.ts_split.split`` module.
    """

    def test_bonded(self):
        """Check that bonded() accepts short C/H pair distances and rejects longer ones."""
        bonded_cases = [(1.0, 'C', 'C'), (1.0, 'C', 'H'), (1.0, 'H', 'C'), (1.0, 'H', 'H')]
        for distance, symbol_1, symbol_2 in bonded_cases:
            self.assertTrue(split.bonded(distance, symbol_1, symbol_2))
        non_bonded_cases = [(1.7, 'C', 'C'), (1.5, 'C', 'H'), (1.5, 'H', 'C'), (1.5, 'H', 'H')]
        for distance, symbol_1, symbol_2 in non_bonded_cases:
            self.assertFalse(split.bonded(distance, symbol_1, symbol_2))

    def test_get_adjlist_from_dmat(self):
        """Check the adjacency list built from a raw distance matrix and from an xyz-derived one."""
        # Case 1: an explicit distance matrix given as nested lists.
        elements = ('O', 'N', 'C', 'H', 'H', 'S', 'H')
        distances = [[0., 1.44678738, 2.1572649, 3.07926623, 2.69780089, 1.74022888, 1.95867823],
                     [1.44678738, 0., 1.47078693, 2.16662322, 2.12283495, 2.34209263, 1.02337844],
                     [2.1572649, 1.47078693, 0., 1.09133324, 1.09169397, 1.82651322, 2.02956962],
                     [3.07926623, 2.16662322, 1.09133324, 0., 1.80097071, 2.51409166, 2.30585633],
                     [2.69780089, 2.12283495, 1.09169397, 1.80097071, 0., 2.45124337, 2.92889793],
                     [1.74022888, 2.34209263, 1.82651322, 2.51409166, 2.45124337, 0., 2.68310024],
                     [1.95867823, 1.02337844, 2.02956962, 2.30585633, 2.92889793, 2.68310024, 0.]]
        # b is incorrect chemically
        result_1 = split.get_adjlist_from_dmat(dmat=distances, symbols=elements, h=3, a=2, b=5)
        expected_1 = {0: [1, 5], 1: [0, 2, 6], 2: [1, 4, 5, 3], 4: [2], 5: [0, 2, 3], 6: [1], 3: [2, 5]}
        self.assertEqual(result_1, expected_1)

        # Case 2: a distance matrix computed from Cartesian coordinates.
        xyz = """ C -3.80799396 1.05904061 0.12143410
H -3.75776386 0.09672979 -0.34366835
H -3.24934849 1.76454448 -0.45741718
H -4.82886508 1.37420677 0.17961125
C -3.21502590 0.97505234 1.54021348
H -3.26525696 1.93736874 2.00531040
H -3.77366533 0.26954471 2.11907272
C -1.74572494 0.52144864 1.45646938
H -1.18708880 1.22694232 0.87759054
H -1.69550074 -0.44087971 0.99139265
O -0.57307243 0.35560699 4.26172088
H -1.12770789 0.43395779 2.93512192
O 0.45489302 1.17807207 4.35811043
H 1.12427554 0.93029226 3.71613651"""
        xyz_dict = check_xyz_dict(xyz)
        result_2 = split.get_adjlist_from_dmat(dmat=xyz_to_dmat(xyz_dict),
                                               symbols=xyz_dict['symbols'],
                                               h=11,
                                               a=7,
                                               b=10)
        expected_2 = {0: [1, 2, 3, 4],
                      1: [0],
                      2: [0],
                      3: [0],
                      4: [0, 5, 6, 7],
                      5: [4],
                      6: [4],
                      7: [4, 8, 9, 11],
                      8: [7],
                      9: [7],
                      10: [12, 11],
                      12: [10, 13],
                      13: [12],
                      11: [7, 10]}
        self.assertEqual(result_2, expected_2)

    def test_iterative_dfs(self):
        """Check that the DFS collects each side of the border atom, in the expected visiting order."""
        graph = {0: [1, 2, 3, 4], 1: [0], 2: [0], 3: [0], 4: [0, 5, 6, 7], 5: [4], 6: [4], 7: [4, 8, 9, 11], 8: [7],
                 9: [7], 10: [12, 11], 12: [10, 13], 13: [12], 11: [7, 10]}
        group_from_10 = split.iterative_dfs(adjlist=graph, start=10, border=11)
        self.assertEqual(group_from_10, [11, 10, 12, 13])
        group_from_7 = split.iterative_dfs(adjlist=graph, start=7, border=11)
        self.assertEqual(group_from_7, [11, 7, 9, 8, 4, 6, 5, 0, 3, 2, 1])
|
||
|
||
|
||
|
||
|
||
# When this file is executed directly as a script, run its tests with a verbose text runner.
if __name__ == '__main__':
    unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
""" | ||
TS Split | ||
""" | ||
|
||
from typing import Dict, List, Tuple, Union | ||
import numpy as np | ||
from collections import deque | ||
|
||
from arc.common import SINGLE_BOND_LENGTH | ||
from arc.species.converter import xyz_to_dmat, translate_to_center_of_mass | ||
|
||
|
||
# Distance cutoff used by get_adjlist_from_dmat(): atom pairs whose distance-matrix entry exceeds this value
# are skipped before the per-element bonded() check.
# NOTE(review): units follow the distance matrix (presumably Angstrom) - confirm.
MAX_LENGTH = 3.0
|
||
|
||
def get_group_xyzs_and_key_indices_from_ts(xyz: dict,
                                           a: int,
                                           b: int,
                                           h: int,
                                           ) -> Tuple[dict, dict, dict]:
    """
    Split an H abstraction TS into its two groups and locate the key atoms inside each group.

    Both groups contain the abstracted H atom: group 1 is explored from atom ``a``, group 2 from atom ``b``.

    Args:
        xyz (dict): The TS xyz.
        a (int): The index of one of the heavy atoms H is connected to.
        b (int): The index of the other heavy atom H is connected to.
        h (int): The index of the H atom being abstracted.

    Returns:
        Tuple[dict, dict, dict]:
            - The xyz of group 1.
            - The xyz of group 2.
            - A dictionary with the keys 'g1_a', 'g1_h', 'g2_a', 'g2_h'; the values are the respective
              atom indices in the returned group xyzs.
    """
    group_1_indices, group_2_indices = divide_h_abs_ts_int_groups(xyz, a, b, h)
    group_1_xyz, group_1_map = split_xyz_by_indices(xyz=xyz, indices=group_1_indices)
    group_2_xyz, group_2_map = split_xyz_by_indices(xyz=xyz, indices=group_2_indices)
    # Translate the original TS indices of the key atoms into indices within each group xyz.
    key_indices = dict()
    key_indices['g1_a'] = group_1_map[a]
    key_indices['g1_h'] = group_1_map[h]
    key_indices['g2_a'] = group_2_map[b]
    key_indices['g2_h'] = group_2_map[h]
    return group_1_xyz, group_2_xyz, key_indices
|
||
|
||
def split_xyz_by_indices(xyz: dict,
                         indices: List[int],
                         ) -> Tuple[dict, Dict[int, int]]:
    """
    Extract the atoms with the given indices from an XYZ dictionary into a new XYZ dictionary.
    Atoms keep their relative order from the original XYZ, and the new XYZ is passed through
    ``translate_to_center_of_mass()`` before being returned.

    Args:
        xyz (dict): The XYZ dictionary (with 'symbols', 'isotopes', and 'coords' entries).
        indices (List[int]): The indices of the atoms to keep.

    Returns:
        Tuple[dict, Dict[int, int]]:
            - The split XYZ dictionary.
            - A map of the kept atoms: keys are indices in the original XYZ,
              values are the corresponding indices in the returned XYZ.
    """
    # Build the lookup once: testing membership in a set is O(1), unlike scanning the list for every atom.
    selected = set(indices)
    new_xyz = {key: tuple(entry for i, entry in enumerate(xyz[key]) if i in selected)
               for key in ('symbols', 'isotopes', 'coords')}
    kept_indices = (i for i in range(len(xyz['symbols'])) if i in selected)
    map_ = {original_index: new_index for new_index, original_index in enumerate(kept_indices)}
    new_xyz = translate_to_center_of_mass(new_xyz)
    return new_xyz, map_
|
||
|
||
def divide_h_abs_ts_int_groups(xyz: dict,
                               a: int,
                               b: int,
                               h: int,
                               ) -> Tuple[List[int], List[int]]:
    """
    Partition the atoms of an H abstraction TS into two groups around the abstracted H atom.
    Each returned group includes the abstracted H (so R1H, R2H).

    Args:
        xyz (dict): The TS xyz.
        a (int): The index of one of the heavy atoms H is connected to.
        b (int): The index of the other heavy atom H is connected to.
        h (int): The index of the H atom being abstracted.

    Returns:
        Tuple[List[int], List[int]]:
            - The atom indices of the group explored from ``a``.
            - The atom indices of the group explored from ``b``.
    """
    adjacency = get_adjlist_from_dmat(xyz_to_dmat(xyz), xyz['symbols'], h, a, b)
    # The abstracted H acts as the border, so neither traversal crosses over to the other side.
    group_1 = iterative_dfs(adjacency, start=a, border=h)
    group_2 = iterative_dfs(adjacency, start=b, border=h)
    return group_1, group_2
|
||
|
||
def get_adjlist_from_dmat(dmat: Union[np.ndarray, list],
                          symbols: Tuple[str, ...],
                          h: int,
                          a: int,
                          b: int,
                          ) -> Dict[int, List[int]]:
    """
    Get an adjacency list from a distance matrix.

    Bonds between all atoms other than ``h`` are perceived from the distances (pairs farther apart than
    ``MAX_LENGTH`` are skipped, the rest are checked with ``bonded()``). The abstracted H atom is excluded
    from this perception and is instead connected explicitly to both ``a`` and ``b``.

    Args:
        dmat (Union[np.ndarray, list]): The distance matrix.
        symbols (Tuple[str, ...]): The chemical elements.
        h (int): The index of the H atom being abstracted (all indices are 0-indexed).
        a (int): The index of one of the heavy atoms H is connected to.
        b (int): The index of the other heavy atom H is connected to.

    Returns:
        Dict[int, List[int]]: The adjlist. Keys are atom indices, values are lists of neighbor indices.
                              Atoms with no perceived neighbors (other than ``h``, ``a``, ``b``) have no entry.
    """
    n_atoms = len(symbols)
    adjlist = dict()
    for atom_1 in range(n_atoms):
        if atom_1 == h:
            continue
        neighbors = [atom_2 for atom_2 in range(n_atoms)
                     if atom_2 != h and atom_2 != atom_1
                     and dmat[atom_1][atom_2] <= MAX_LENGTH
                     and bonded(dmat[atom_1][atom_2], symbols[atom_1], symbols[atom_2])]
        if neighbors:
            adjlist[atom_1] = neighbors
    adjlist[h] = [a, b]
    # ``a`` or ``b`` may have no perceived neighbor (e.g., abstraction by a single atom),
    # in which case it has no entry yet; create one rather than raising a KeyError.
    adjlist.setdefault(a, list()).append(h)
    adjlist.setdefault(b, list()).append(h)
    return adjlist
|
||
|
||
def bonded(distance: float, s1: str, s2: str, tolerance: float = 1.3) -> bool:
    """
    Determine whether two atoms are bonded based on their distance and chemical symbols.
    The atoms are considered bonded if their distance does not exceed the reference single bond length
    of this element pair (looked up in ``SINGLE_BOND_LENGTH``) multiplied by ``tolerance``.

    Args:
        distance (float): The distance between the atoms.
        s1 (str): The chemical symbol of the first atom.
        s2 (str): The chemical symbol of the second atom.
        tolerance (float, optional): The factor by which the reference single bond length is scaled
                                     to get the maximal distance still considered a bond.

    Returns:
        bool: Whether the atoms are bonded. ``False`` if no reference length is known for this element pair.
    """
    # The reference table may store a pair under either element order, so try both.
    ref_dist = SINGLE_BOND_LENGTH.get(f'{s1}_{s2}', None) or SINGLE_BOND_LENGTH.get(f'{s2}_{s1}', None)
    if ref_dist is None:
        return False
    return bool(distance <= ref_dist * tolerance)
|
||
|
||
def iterative_dfs(adjlist: Dict[int, List[int]],
                  start: int,
                  border: int,
                  ) -> List[int]:
    """
    A depth first search (DFS) graph traversal algorithm to determine indices that belong to a subgroup of the graph.
    The subgroup is explored from the start atom and will not pass the border atom.
    This is an iterative and not a recursive algorithm since Python doesn't have great support for recursion:
    it lacks Tail Recursion Elimination and there is a limit on the recursion stack depth (by default 1000).

    Args:
        adjlist (Dict[int, List[int]]): The adjacency list.
        start (int): The index of the atom to start the DFS from.
        border (int): The index of the atom that is the border of the subgroup.

    Returns:
        List[int]: The indices of atoms in the subgroup in visiting order, beginning with the border atom.
    """
    visited = [border]  # Keeps the visiting order; the border is pre-marked so it is never expanded.
    seen = {border}  # Mirrors ``visited`` for O(1) membership tests.
    stack = [start]
    while stack:
        key = stack.pop()
        if key in seen:
            continue
        seen.add(key)
        visited.append(key)
        # An atom without an adjacency entry simply has no neighbors to explore.
        stack.extend(neighbor for neighbor in adjlist.get(key, ()) if neighbor not in seen)
    return visited