Skip to content

Commit

Permalink
TMP
Browse files Browse the repository at this point in the history
  • Loading branch information
alongd committed Dec 12, 2024
1 parent 6fda63e commit 17ffd87
Show file tree
Hide file tree
Showing 2 changed files with 278 additions and 0 deletions.
99 changes: 99 additions & 0 deletions arc/ts_split/main_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#!/usr/bin/env python3
# encoding: utf-8


import unittest

from arc.imports import settings
import arc.ts_split.split as split
from arc.species.converter import check_xyz_dict, xyz_to_dmat

# Server definitions read from ARC's settings; none of the tests visible in this module reference it.
servers = settings['servers']


class TestTSSplit(unittest.TestCase):
    """
    Unit tests for the helper functions of ARC's TS split module.
    """

    def test_bonded(self):
        """Check that bonded() accepts short C/H pair distances and rejects long ones."""
        element_pairs = (('C', 'C'), ('C', 'H'), ('H', 'C'), ('H', 'H'))
        # A 1.0 distance counts as bonded for every C/H combination.
        for first, second in element_pairs:
            self.assertTrue(split.bonded(1.0, first, second))
        # Longer distances (one per pair, in the same order) must be rejected.
        rejected_distances = (1.7, 1.5, 1.5, 1.5)
        for distance, (first, second) in zip(rejected_distances, element_pairs):
            self.assertFalse(split.bonded(distance, first, second))

    def test_get_adjlist_from_dmat(self):
        """Check the adjacency list built from an explicit distance matrix and from an XYZ-derived one."""
        # Case 1: a hand-written 7x7 distance matrix.
        elements = ('O', 'N', 'C', 'H', 'H', 'S', 'H')
        distances = [[0., 1.44678738, 2.1572649, 3.07926623, 2.69780089, 1.74022888, 1.95867823],
                     [1.44678738, 0., 1.47078693, 2.16662322, 2.12283495, 2.34209263, 1.02337844],
                     [2.1572649, 1.47078693, 0., 1.09133324, 1.09169397, 1.82651322, 2.02956962],
                     [3.07926623, 2.16662322, 1.09133324, 0., 1.80097071, 2.51409166, 2.30585633],
                     [2.69780089, 2.12283495, 1.09169397, 1.80097071, 0., 2.45124337, 2.92889793],
                     [1.74022888, 2.34209263, 1.82651322, 2.51409166, 2.45124337, 0., 2.68310024],
                     [1.95867823, 1.02337844, 2.02956962, 2.30585633, 2.92889793, 2.68310024, 0.]]
        expected_from_matrix = {0: [1, 5],
                                1: [0, 2, 6],
                                2: [1, 4, 5, 3],
                                4: [2],
                                5: [0, 2, 3],
                                6: [1],
                                3: [2, 5]}
        # b is incorrect chemically
        from_matrix = split.get_adjlist_from_dmat(dmat=distances, symbols=elements, h=3, a=2, b=5)
        self.assertEqual(from_matrix, expected_from_matrix)

        # Case 2: the distance matrix is derived from Cartesian coordinates.
        ts_xyz = """ C -3.80799396 1.05904061 0.12143410
H -3.75776386 0.09672979 -0.34366835
H -3.24934849 1.76454448 -0.45741718
H -4.82886508 1.37420677 0.17961125
C -3.21502590 0.97505234 1.54021348
H -3.26525696 1.93736874 2.00531040
H -3.77366533 0.26954471 2.11907272
C -1.74572494 0.52144864 1.45646938
H -1.18708880 1.22694232 0.87759054
H -1.69550074 -0.44087971 0.99139265
O -0.57307243 0.35560699 4.26172088
H -1.12770789 0.43395779 2.93512192
O 0.45489302 1.17807207 4.35811043
H 1.12427554 0.93029226 3.71613651"""
        ts_xyz_dict = check_xyz_dict(ts_xyz)
        ts_dmat = xyz_to_dmat(ts_xyz_dict)
        expected_from_xyz = {0: [1, 2, 3, 4],
                             1: [0],
                             2: [0],
                             3: [0],
                             4: [0, 5, 6, 7],
                             5: [4],
                             6: [4],
                             7: [4, 8, 9, 11],
                             8: [7],
                             9: [7],
                             10: [12, 11],
                             12: [10, 13],
                             13: [12],
                             11: [7, 10]}
        from_xyz = split.get_adjlist_from_dmat(dmat=ts_dmat, symbols=ts_xyz_dict['symbols'], h=11, a=7, b=10)
        self.assertEqual(from_xyz, expected_from_xyz)

    def test_iterative_dfs(self):
        """Check the traversal order returned by iterative_dfs() on both sides of the border atom."""
        graph = {0: [1, 2, 3, 4], 1: [0], 2: [0], 3: [0], 4: [0, 5, 6, 7], 5: [4], 6: [4],
                 7: [4, 8, 9, 11], 8: [7], 9: [7], 10: [12, 11], 12: [10, 13], 13: [12], 11: [7, 10]}
        # Each case is (start atom, expected visiting order); atom 11 is the border in both.
        cases = ((10, [11, 10, 12, 13]),
                 (7, [11, 7, 9, 8, 4, 6, 5, 0, 3, 2, 1]))
        for start_atom, expected_order in cases:
            group = split.iterative_dfs(adjlist=graph, start=start_atom, border=11)
            self.assertEqual(group, expected_order)





# When executed directly, run this module's tests through a TextTestRunner with verbosity=2.
if __name__ == '__main__':
    unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))









179 changes: 179 additions & 0 deletions arc/ts_split/split.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""
TS Split
"""

from typing import Dict, List, Tuple, Union
import numpy as np
from collections import deque

from arc.common import SINGLE_BOND_LENGTH
from arc.species.converter import xyz_to_dmat, translate_to_center_of_mass


# Distance cutoff used by get_adjlist_from_dmat(): atom pairs farther apart than this are never tested for bonding.
# NOTE(review): presumably in Angstrom, matching the distance matrix units - confirm.
MAX_LENGTH = 3.0


def get_group_xyzs_and_key_indices_from_ts(xyz: dict,
                                           a: int,
                                           b: int,
                                           h: int,
                                           ) -> Tuple[dict, dict, dict]:
    """
    Split an H abstraction TS into its two groups around the abstracted H atom.

    Each group contains the abstracted H atom, so the two returned geometries overlap in that atom.

    Args:
        xyz (dict): The TS xyz.
        a (int): The index of one of the heavy atoms H is connected to.
        b (int): The index of the other heavy atom H is connected to.
        h (int): The index of the H atom being abstracted.

    Returns:
        Tuple[dict, dict, dict]:
            - xyz of group 1 (the group explored from ``a``).
            - xyz of group 2 (the group explored from ``b``).
            - Keys are 'g1_a', 'g1_h', 'g2_a', 'g2_h', values are atom indices in the returned group xyzs.
    """
    group_1, group_2 = divide_h_abs_ts_int_groups(xyz, a, b, h)
    xyz_1, map_1 = split_xyz_by_indices(xyz=xyz, indices=group_1)
    xyz_2, map_2 = split_xyz_by_indices(xyz=xyz, indices=group_2)
    # Translate the original TS indices of the key atoms into each group's own numbering.
    key_indices = dict()
    key_indices['g1_a'] = map_1[a]
    key_indices['g1_h'] = map_1[h]
    key_indices['g2_a'] = map_2[b]
    key_indices['g2_h'] = map_2[h]
    return xyz_1, xyz_2, key_indices


def split_xyz_by_indices(xyz: dict,
                         indices: List[int],
                         ) -> Tuple[dict, Dict[int, int]]:
    """
    Extract the atoms at the given indices from an XYZ dictionary.
    Also, map each kept original atom index to its new index in the returned XYZ.

    Atoms keep their original relative order regardless of the order of ``indices``.
    The extracted XYZ is passed through ``translate_to_center_of_mass`` before it is returned.

    Args:
        xyz (dict): The XYZ dictionary, with 'symbols', 'isotopes', and 'coords' entries.
        indices (List[int]): The indices of the atoms to keep.

    Returns:
        Tuple[dict, Dict[int, int]]:
            - The split XYZ dictionary.
            - The index map. Keys are the original indices, values are the indices in the returned XYZ.
    """
    # Hash the requested indices once so every membership test below is O(1) instead of a list scan.
    selected = set(indices)
    new_xyz = {key: tuple(entry for i, entry in enumerate(xyz[key]) if i in selected)
               for key in ('symbols', 'isotopes', 'coords')}
    # Kept atoms are renumbered consecutively from 0 in their original order.
    kept = (i for i in range(len(xyz['symbols'])) if i in selected)
    map_ = {original: new for new, original in enumerate(kept)}
    new_xyz = translate_to_center_of_mass(new_xyz)
    return new_xyz, map_


def divide_h_abs_ts_int_groups(xyz: dict,
                               a: int,
                               b: int,
                               h: int,
                               ) -> Tuple[List[int], List[int]]:
    """
    Divide the atoms in the TS into two groups based on the H atom being abstracted.
    Both groups include the abstracted H (so R1H, R2H).

    Args:
        xyz (dict): The TS xyz.
        a (int): The index of one of the heavy atoms H is connected to.
        b (int): The index of the other heavy atom H is connected to.
        h (int): The index of the H atom being abstracted.

    Returns:
        Tuple[List[int], List[int]]:
            - The atom indices of the group explored from ``a``.
            - The atom indices of the group explored from ``b``.
    """
    connectivity = get_adjlist_from_dmat(xyz_to_dmat(xyz), xyz['symbols'], h, a, b)
    # The abstracted H is the border of both searches, so neither group crosses over to the other side.
    group_a = iterative_dfs(connectivity, start=a, border=h)
    group_b = iterative_dfs(connectivity, start=b, border=h)
    return group_a, group_b


def get_adjlist_from_dmat(dmat: Union[np.ndarray, list],
                          symbols: Tuple[str, ...],
                          h: int,
                          a: int,
                          b: int,
                          ) -> Dict[int, List[int]]:
    """
    Get an adjacency list from a distance matrix.

    Bonds between all atoms other than ``h`` are perceived from the distances.
    The abstracted H atom is excluded from that perception and is instead connected explicitly
    to ``a`` and ``b`` only. Atoms for which no bond is found do not get an entry,
    except for ``a`` and ``b``, which always have at least ``h`` as a neighbor.

    Args:
        dmat (Union[np.ndarray, list]): The distance matrix.
        symbols (Tuple[str, ...]): The chemical elements.
        h (int): The index of the H atom being abstracted (all indices are 0-indexed).
        a (int): The index of one of the heavy atoms H is connected to.
        b (int): The index of the other heavy atom H is connected to.

    Returns:
        Dict[int, List[int]]: The adjlist. Keys are atom indices, values are lists of neighbor indices.
    """
    adjlist = dict()
    for atom_1, symbol_1 in enumerate(symbols):
        if atom_1 == h:
            continue
        for atom_2, symbol_2 in enumerate(symbols):
            if atom_2 == h or atom_2 == atom_1:
                continue
            distance = dmat[atom_1][atom_2]
            # The cheap cutoff comparison runs first, so bonded() is only consulted for nearby pairs.
            if distance <= MAX_LENGTH and bonded(distance, symbol_1, symbol_2):
                adjlist.setdefault(atom_1, list()).append(atom_2)
    adjlist[h] = [a, b]
    # a or b may have no perceived bonds at all (e.g., abstraction by a lone atom),
    # so create their entries on demand instead of assuming they already exist.
    adjlist.setdefault(a, list()).append(h)
    adjlist.setdefault(b, list()).append(h)
    return adjlist


def bonded(distance: float, s1: str, s2: str) -> bool:
    """
    Decide whether two atoms are bonded by comparing their distance to a reference single bond length.

    The reference is looked up in ``SINGLE_BOND_LENGTH`` under '<s1>_<s2>' and then under '<s2>_<s1>'.
    A pair with no reference entry is considered not bonded.

    Args:
        distance (float): The distance between the atoms.
        s1 (str): The chemical symbol of the first atom.
        s2 (str): The chemical symbol of the second atom.

    Returns:
        bool: Whether the atoms are bonded.
    """
    forward_key, reverse_key = f'{s1}_{s2}', f'{s2}_{s1}'
    reference = SINGLE_BOND_LENGTH.get(forward_key) or SINGLE_BOND_LENGTH.get(reverse_key)
    if reference is None:
        return False
    # todo: test & magic number (1.3), make CONSTANT
    return bool(distance <= reference * 1.3)


def iterative_dfs(adjlist: Dict[int, List[int]],
                  start: int,
                  border: int,
                  ) -> List[int]:
    """
    A depth first search (DFS) graph traversal algorithm to determine indices that belong to a subgroup of the graph.
    The subgroup is explored from the start atom and will not pass the border atom.
    This is an iterative and not a recursive algorithm, which avoids Python's recursion depth limit.

    An atom that has no entry in ``adjlist`` is treated as having no neighbors.

    Args:
        adjlist (Dict[int, List[int]]): The adjacency list.
        start (int): The index of the atom to start the DFS from.
        border (int): The index of the atom that is the border of the subgroup.

    Returns:
        List[int]: The indices of atoms in the subgroup, in visiting order, with the border atom first.
    """
    # The list records the visiting order that is returned; the set gives O(1) membership tests.
    visited = [border]
    seen = {border}
    stack = deque([start])
    while stack:
        key = stack.pop()
        if key in seen:
            # The same atom may be pushed more than once before it is first popped.
            continue
        seen.add(key)
        visited.append(key)
        for neighbor in adjlist.get(key, ()):
            if neighbor not in seen:
                stack.append(neighbor)
    return visited

0 comments on commit 17ffd87

Please sign in to comment.