diff --git a/arc/ts_split/main_test.py b/arc/ts_split/main_test.py new file mode 100644 index 0000000000..3b1f79ebcf --- /dev/null +++ b/arc/ts_split/main_test.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + + +import unittest + +from arc.imports import settings +import arc.ts_split.split as split +from arc.species.converter import check_xyz_dict, xyz_to_dmat + +servers = settings['servers'] + + +class TestTSSplit(unittest.TestCase): + """ + Contains unit tests for ARC's split module + """ + + def test_bonded(self): + """Test bonded""" + self.assertTrue(split.bonded(1.0, 'C', 'C')) + self.assertTrue(split.bonded(1.0, 'C', 'H')) + self.assertTrue(split.bonded(1.0, 'H', 'C')) + self.assertTrue(split.bonded(1.0, 'H', 'H')) + self.assertFalse(split.bonded(1.7, 'C', 'C')) + self.assertFalse(split.bonded(1.5, 'C', 'H')) + self.assertFalse(split.bonded(1.5, 'H', 'C')) + self.assertFalse(split.bonded(1.5, 'H', 'H')) + + def test_get_adjlist_from_dmat(self): + """Test get_adjlist_from_dmat""" + symbols = ('O', 'N', 'C', 'H', 'H', 'S', 'H') + d = [[0., 1.44678738, 2.1572649, 3.07926623, 2.69780089, 1.74022888, 1.95867823], + [1.44678738, 0., 1.47078693, 2.16662322, 2.12283495, 2.34209263, 1.02337844], + [2.1572649, 1.47078693, 0., 1.09133324, 1.09169397, 1.82651322, 2.02956962], + [3.07926623, 2.16662322, 1.09133324, 0., 1.80097071, 2.51409166, 2.30585633], + [2.69780089, 2.12283495, 1.09169397, 1.80097071, 0., 2.45124337, 2.92889793], + [1.74022888, 2.34209263, 1.82651322, 2.51409166, 2.45124337, 0., 2.68310024], + [1.95867823, 1.02337844, 2.02956962, 2.30585633, 2.92889793, 2.68310024, 0.]] + adjlist = split.get_adjlist_from_dmat(dmat=d, symbols=symbols, h=3, a=2, b=5) # b is incorrect chemically + self.assertEqual(adjlist, {0: [1, 5], 1: [0, 2, 6], 2: [1, 4, 5, 3], 4: [2], 5: [0, 2, 3], 6: [1], 3: [2, 5]}) + + xyz = """ C -3.80799396 1.05904061 0.12143410 + H -3.75776386 0.09672979 -0.34366835 + H -3.24934849 1.76454448 -0.45741718 + H -4.82886508 1.37420677 0.17961125 + C -3.21502590 0.97505234 1.54021348 + H -3.26525696 1.93736874 2.00531040 + H -3.77366533 0.26954471 2.11907272 + C -1.74572494 0.52144864 1.45646938 + H -1.18708880 1.22694232 0.87759054 + H -1.69550074 -0.44087971 0.99139265 + O -0.57307243 0.35560699 4.26172088 + H -1.12770789 0.43395779 2.93512192 + O 0.45489302 1.17807207 4.35811043 + H 1.12427554 0.93029226 3.71613651""" + xyz_dict = check_xyz_dict(xyz) + dmat = xyz_to_dmat(xyz_dict) + adjlist = split.get_adjlist_from_dmat(dmat=dmat, symbols=xyz_dict['symbols'], h=11, a=7, b=10) + self.assertEqual(adjlist, + {0: [1, 2, 3, 4], + 1: [0], + 2: [0], + 3: [0], + 4: [0, 5, 6, 7], + 5: [4], + 6: [4], + 7: [4, 8, 9, 11], + 8: [7], + 9: [7], + 10: [12, 11], + 12: [10, 13], + 13: [12], + 11: [7, 10]}) + + def test_iterative_dfs(self): + """Test iterative_dfs""" + adjlist = {0: [1, 2, 3, 4], 1: [0], 2: [0], 3: [0], 4: [0, 5, 6, 7], 5: [4], 6: [4], 7: [4, 8, 9, 11], 8: [7], + 9: [7], 10: [12, 11], 12: [10, 13], 13: [12], 11: [7, 10]} + g1 = split.iterative_dfs(adjlist=adjlist, start=10, border=11) + self.assertEqual(g1, [11, 10, 12, 13]) + g2 = split.iterative_dfs(adjlist=adjlist, start=7, border=11) + self.assertEqual(g2, [11, 7, 9, 8, 4, 6, 5, 0, 3, 2, 1]) + + + + + +if __name__ == '__main__': + unittest.main(testRunner=unittest.TextTestRunner(verbosity=2)) + + + + + + + + + diff --git a/arc/ts_split/split.py b/arc/ts_split/split.py new file mode 100644 index 0000000000..1e454ecd64 --- /dev/null +++ b/arc/ts_split/split.py @@ -0,0 +1,179 @@ +""" +TS Split +""" + +from typing import Dict, List, Tuple, Union +import numpy as np +from collections import deque + +from arc.common import SINGLE_BOND_LENGTH +from arc.species.converter import xyz_to_dmat, translate_to_center_of_mass + + +MAX_LENGTH = 3.0 + + +def get_group_xyzs_and_key_indices_from_ts(xyz: dict, + a: int, + b: int, + h: int, + ) -> Tuple[dict, dict, dict]: + """ + Get the two corresponding XYZs of groups in an H abstraction TS based on the H atom being abstracted. + + Args: + xyz (dict): The TS xyz. + a (int): The index of one of the heavy atoms H is connected to. + b (int): The index of the other heavy atoms H is connected to. + h (int): The index of the H atom being abstract + + Returns: + Tuple[dict, dict, dict]: + - xyz of group 1 + - xyz of group 2 + - Keys are 'g1_a', 'g1_h', 'g2_a', 'g2_h', values are atom indices i the returned group xyzs. + """ + g1, g2 = divide_h_abs_ts_int_groups(xyz, a, b, h) + g1_xyz, g1_map = split_xyz_by_indices(xyz=xyz, indices=g1) + g2_xyz, g2_map = split_xyz_by_indices(xyz=xyz, indices=g2) + index_dict = {'g1_a': g1_map[a], 'g1_h': g1_map[h], 'g2_a': g2_map[b], 'g2_h': g2_map[h]} + return g1_xyz, g2_xyz, index_dict + + +def split_xyz_by_indices(xyz: dict, + indices: List[int], + ) -> Tuple[dict, Dict[int, int]]: + """ + Split an XYZ dictionary by indices. + Also, map the indices in to_map_indices to the new indices in the returned XYZ. + + Args: + xyz (dict): The XYZ dictionary. + indices (List[int]): The indices to split by. + + Returns: + Tuple[dict, dict[int, int]]: + - The split XYZ dictionary. + - The new indices of the atoms in to_map_indices in the returned XYZ. Keys are the original indices. + """ + new_xyz = dict() + new_xyz['symbols'] = tuple(symbol for i, symbol in enumerate(xyz['symbols']) if i in indices) + new_xyz['isotopes'] = tuple(isotope for i, isotope in enumerate(xyz['isotopes']) if i in indices) + new_xyz['coords'] = tuple(coord for i, coord in enumerate(xyz['coords']) if i in indices) + mapped_index = 0 + map_ = dict() + for i in range(len(xyz['symbols'])): + if i in indices: + map_[i] = mapped_index + mapped_index += 1 + new_xyz = translate_to_center_of_mass(new_xyz) + return new_xyz, map_ + + +def divide_h_abs_ts_int_groups(xyz: dict, + a: int, + b: int, + h: int, + ) -> Tuple[List[int], List[int]]: + """ + Divide the atoms in the TS into two groups based on the H atom being abstracted. + Get the indices of the atoms in the two groups, each includes the abstracted H (so R1H, R2H). + """ + dmat = xyz_to_dmat(xyz) + symbols = xyz['symbols'] + adjlist = get_adjlist_from_dmat(dmat, symbols, h, a, b) + g1 = iterative_dfs(adjlist, start=a, border=h) + g2 = iterative_dfs(adjlist, start=b, border=h) + return g1, g2 + + +def get_adjlist_from_dmat(dmat: Union[np.ndarray, list], + symbols: Tuple[str, ...], + h: int, + a: int, + b: int, + ) -> Dict[int, List[int]]: + """ + Get an adjacency list from a DMat. + + Args: + dmat (np.ndarray): The distance matrix. + symbols (Tuple[str]): THe chemical elements. + h (int): The index of the H atom being abstracted (all indices are 0-indexed) + a (int): The index of one of the heavy atoms H is connected to. + b (int): The index of the other heavy atoms H is connected to. + + Returns: + Dict[int, List[int]]: The adjlist. + """ + adjlist = dict() + for atom_1 in range(len(symbols)): + if atom_1 == h: + continue + for atom_2 in range(len(symbols)): + if atom_2 in [h, atom_1]: + continue + if dmat[atom_1][atom_2] <= MAX_LENGTH: + + if bonded(dmat[atom_1][atom_2], symbols[atom_1], symbols[atom_2]): + if atom_1 not in adjlist: + adjlist[atom_1] = list() + adjlist[atom_1].append(atom_2) + adjlist[h] = [a, b] + adjlist[a].append(h) + adjlist[b].append(h) + return adjlist + + +def bonded(distance: float, s1: str, s2: str) -> bool: + """ + Determine whether two atoms are bonded based on their distance and chemical symbols. + + Args: + distance (float): The distance between the atoms. + s1 (str): The chemical symbol of the first atom. + s2 (str): The chemical symbol of the second atom. + + Returns: + bool: Whether the atoms are bonded. + """ + bond_key = f'{s1}_{s2}' + ref_dist = SINGLE_BOND_LENGTH.get(bond_key, None) or SINGLE_BOND_LENGTH.get(f'{s2}_{s1}', None) + if ref_dist is None: + return False + if distance <= ref_dist * 1.3: # todo: test & magic number, make CONSTANT + return True + return False + + +def iterative_dfs(adjlist: Dict[int, List[int]], + start: int, + border: int, + ) -> List[int]: + """ + A depth first search (DFS) graph traversal algorithm to determine indices that belong to a subgroup of the graph. + The subgroup is being explored from the key atom and will not pass the border atom. + This is an iterative and not a recursive algorithm since Python doesn't have a great support for recursion + since it lacks Tail Recursion Elimination and because there is a limit of recursion stack depth (by default is 1000). + + Args: + adjlist (Dict[int, List[int]]): The adjacency list. + start (int): The index of the atom to start the DFS from. + border (int): The index of the atom that is the border of the subgroup. + + Returns: + List[int]: The indices of atoms in the subgroup including the border atom. + """ + visited = [border] + stack = deque() + stack.append(start) + while stack: + key = stack.pop() + if key in visited: + continue + visited.append(key) + neighbors = adjlist[key] + for neighbor in neighbors: + if neighbor not in visited: + stack.append(neighbor) + return visited