Skip to content

Commit 2d162ce

Browse files
committed
TMP2
1 parent 0329d24 commit 2d162ce

File tree

17 files changed

+539
-422
lines changed

17 files changed

+539
-422
lines changed

arc/checks/ts.py

-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
)
2020
from arc.imports import settings
2121
from arc.species.converter import check_xyz_dict, displace_xyz, xyz_to_dmat
22-
from arc.mapping.engine import get_atom_indices_of_labeled_atoms_in_an_rmg_reaction
2322
from arc.statmech.factory import statmech_factory
2423

2524
if TYPE_CHECKING:

arc/common.py

-48
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,13 @@
1919
import warnings
2020
import yaml
2121
from collections import deque
22-
from itertools import chain
2322
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
2423

2524
import numpy as np
2625
import pandas as pd
2726
import qcelemental as qcel
2827

2928
from arkane.ess import ess_factory, GaussianLog, MolproLog, OrcaLog, QChemLog, TeraChemLog
30-
import rmgpy
3129
from rmgpy.exceptions import AtomTypeError, ILPSolutionError, ResonanceError
3230
from rmgpy.molecule.atomtype import ATOMTYPES
3331
from rmgpy.molecule.element import get_element
@@ -40,9 +38,7 @@
4038

4139

4240
if TYPE_CHECKING:
43-
from rmgpy.reaction import Reaction
4441
from rmgpy.species import Species
45-
from arc.reaction import ARCReaction
4642

4743

4844
logger = logging.getLogger('arc')
@@ -1563,50 +1559,6 @@ def calc_rmsd(x: Union[list, np.array],
15631559
return float(rmsd)
15641560

15651561

1566-
def _check_r_n_p_symbols_between_rmg_and_arc_rxns(arc_reaction: 'ARCReaction',
1567-
rmg_reactions: List['Reaction'],
1568-
) -> bool:
1569-
"""
1570-
A helper function to check that atom symbols are in the correct order between an ARC reaction
1571-
and its corresponding RMG reactions generated by the get_rmg_reactions_from_arc_reaction() function.
1572-
Used internally for testing.
1573-
1574-
Args:
1575-
arc_reaction (ARCReaction): The ARCReaction object to inspect.
1576-
rmg_reactions (List['Reaction']): Entries are RMG Reaction objects to inspect.
1577-
Could contain either Species or Molecule object as reactants/products.
1578-
1579-
Returns:
1580-
bool: Whether atom symbols are in the same respective order.
1581-
"""
1582-
result = True
1583-
num_rs, num_ps = len(arc_reaction.r_species), len(arc_reaction.p_species)
1584-
arc_r_symbols = [atom.element.symbol for atom in chain(*tuple(arc_reaction.r_species[i].mol.atoms for i in range(num_rs)))]
1585-
arc_p_symbols = [atom.element.symbol for atom in chain(*tuple(arc_reaction.p_species[i].mol.atoms for i in range(num_ps)))]
1586-
for rmg_reaction in rmg_reactions:
1587-
rmg_r_symbols = [atom.element.symbol
1588-
for atom in chain(*tuple(rmg_reaction.reactants[i].atoms
1589-
if isinstance(rmg_reaction.reactants[i], Molecule)
1590-
else rmg_reaction.reactants[i].molecule[0].atoms
1591-
for i in range(num_rs)))]
1592-
rmg_p_symbols = [atom.element.symbol
1593-
for atom in chain(*tuple(rmg_reaction.products[i].atoms
1594-
if isinstance(rmg_reaction.products[i], Molecule)
1595-
else rmg_reaction.products[i].molecule[0].atoms
1596-
for i in range(num_ps)))]
1597-
if any(symbol_1 != symbol_2 for symbol_1, symbol_2 in zip(arc_r_symbols, rmg_r_symbols)):
1598-
print('\nDifferent element order in reactants between ARC and RMG:') # Don't modify to logging.
1599-
print(arc_r_symbols)
1600-
print(rmg_r_symbols)
1601-
result = False
1602-
if any(symbol_1 != symbol_2 for symbol_1, symbol_2 in zip(arc_p_symbols, rmg_p_symbols)):
1603-
print('\nDifferent element order in products between ARC and RMG:')
1604-
print(arc_p_symbols)
1605-
print(rmg_p_symbols)
1606-
result = False
1607-
return result
1608-
1609-
16101562
def safe_copy_file(source: str,
16111563
destination: str,
16121564
wait: int = 10,

arc/common_test.py

-9
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@
2121
import arc.common as common
2222
from arc.exceptions import InputError, SettingsError
2323
from arc.imports import settings
24-
from arc.mapping.engine import get_rmg_reactions_from_arc_reaction
2524
import arc.species.converter as converter
26-
from arc.reaction import ARCReaction
2725
from arc.species.species import ARCSpecies
2826

2927

@@ -1268,13 +1266,6 @@ def test_sort_atoms_in_descending_label_order(self):
12681266
mol.atoms[0].label = "a"
12691267
self.assertIsNone(common.sort_atoms_in_descending_label_order(mol=mol))
12701268

1271-
def test_check_r_n_p_symbols_between_rmg_and_arc_rxns(self):
1272-
"""Test the _check_r_n_p_symbols_between_rmg_and_arc_rxns() function"""
1273-
arc_rxn = ARCReaction(r_species=[ARCSpecies(label='CH4', smiles='C'), ARCSpecies(label='OH', smiles='[OH]')],
1274-
p_species=[ARCSpecies(label='CH3', smiles='[CH3]'), ARCSpecies(label='H2O', smiles='O')])
1275-
rmg_reactions = get_rmg_reactions_from_arc_reaction(arc_reaction=arc_rxn)
1276-
self.assertTrue(common._check_r_n_p_symbols_between_rmg_and_arc_rxns(arc_rxn, rmg_reactions))
1277-
12781269
def test_almost_equal_coords(self):
12791270
"""Test the almost_equal_coords() function"""
12801271
with self.assertRaises(TypeError):

arc/family/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
import arc.family.family
2+
from arc.family.family import ReactionFamily
3+
from arc.family.family import get_reaction_family_products

arc/family/family.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
from arc.common import clean_text, generate_resonance_structures, get_logger
1313
from arc.imports import settings
14-
from arc.species.converter import check_isomorphism
1514

1615
if TYPE_CHECKING:
1716
from arc.species import ARCSpecies
@@ -519,7 +518,7 @@ def check_product_isomorphism(products: List['Molecule'],
519518
Returns:
520519
bool: Whether the products are isomorphic to the species.
521520
"""
522-
prods_a = [generate_resonance_structures(mol) or [mol] for mol in products]
521+
prods_a = [generate_resonance_structures(mol.copy(deep=True)) or [mol.copy(deep=True)] for mol in products]
523522
prods_b = [spc.mol_list or [spc.mol] for spc in p_species]
524523
if len(prods_a) == 1:
525524
prod_a = prods_a[0]
@@ -564,6 +563,7 @@ def get_all_families(rmg_family_set: Union[List[str], str] = 'default',
564563
) -> List[str]:
565564
"""
566565
Get all available RMG and ARC families.
566+
If ``rmg_family_set`` is a list of family labels and does not contain family sets, it will be returned as is.
567567
568568
Args:
569569
rmg_family_set (Union[List[str], str], optional): The RMG family set to use.
@@ -574,9 +574,11 @@ def get_all_families(rmg_family_set: Union[List[str], str] = 'default',
574574
List[str]: A list of all available families.
575575
"""
576576
rmg_family_set = rmg_family_set or 'default'
577+
family_sets = get_rmg_recommended_family_sets()
578+
if isinstance(rmg_family_set, list) and all(fam not in family_sets for fam in rmg_family_set):
579+
return rmg_family_set
577580
rmg_families, arc_families = list(), list()
578581
if consider_rmg_families:
579-
family_sets = get_rmg_recommended_family_sets()
580582
if not isinstance(rmg_families, list) and rmg_family_set not in list(family_sets) + ['all']:
581583
raise ValueError(f'Invalid RMG family set: {rmg_family_set}')
582584
if rmg_family_set == 'all':

arc/family/family_test.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,8 @@ def test_get_all_families(self):
605605
families = get_all_families(consider_rmg_families=False)
606606
self.assertIsInstance(families, list)
607607
self.assertIn('hydrolysis', families)
608+
families = get_all_families(rmg_family_set=['H_Abstraction'])
609+
self.assertEqual(families, ['H_Abstraction'])
608610

609611
def test_get_rmg_recommended_family_sets(self):
610612
"""Test getting RMG recommended family sets"""
@@ -638,12 +640,12 @@ def test_load(self):
638640
self.assertFalse(fam_2.own_reverse)
639641
self.assertEqual(fam_2.reactants, [['Root']])
640642
self.assertEqual(fam_2.product_num, 2)
641-
self.assertEqual(fam_2.entries, {'Root': """1 *3 R!H u0 {2,S} {3,[S,D]}
642-
2 *4 R!H u0 {1,S} {4,[S,D]}
643-
3 *2 R!H u0 {1,[S,D]} {5,[D,T,B]}
644-
4 *5 R!H u0 {2,[S,D]} {6,S}
645-
5 *1 R!H u0 {3,[D,T,B]}
646-
6 *6 H u0 {4,S}"""})
643+
self.assertEqual(fam_2.entries, {'Root': """1 *3 R!H u0 {2,S} {3,[S,D]}
644+
2 *4 R!H u0 {1,S} {4,[S,D]}
645+
3 *2 R!H u0 {1,[S,D]} {5,[D,T,B]}
646+
4 *5 R!H u0 {2,[S,D]} {6,S}
647+
5 *1 R!H u0 {3,[D,T,B]}
648+
6 *6 [H,Li] u0 {4,S}"""})
647649
self.assertEqual(fam_2.actions, [['CHANGE_BOND', '*1', -1, '*2'],
648650
['BREAK_BOND', '*5', 1, '*6'],
649651
['BREAK_BOND', '*3', 1, '*4'],
@@ -960,6 +962,17 @@ def test_get_isomorphic_subgraph(self):
960962
)
961963
self.assertEqual(isomorphic_subgraph, {0: '*3', 4: '*1', 7: '*2'})
962964

965+
# def test_order_species_list(self):
966+
# """Test the order_species_list() function"""
967+
# spc1 = ARCSpecies(label='spc1', smiles='C')
968+
# spc2 = ARCSpecies(label='spc2', smiles='CC')
969+
# ordered_species_list = order_species_list(species_list=[spc2, spc1], reference_species=[spc1, spc2])
970+
# self.assertEqual(ordered_species_list, [spc1, spc2])
971+
# ordered_species_list = order_species_list(species_list=[spc2, spc1], reference_species=[spc2, spc1])
972+
# self.assertEqual(ordered_species_list, [spc2, spc1])
973+
# ordered_species_list = order_species_list(species_list=[spc2.mol, spc1], reference_species=[spc2, spc1.mol])
974+
# self.assertEqual(ordered_species_list, [spc2.mol, spc1])
975+
963976

964977
if __name__ == '__main__':
965978
unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))

0 commit comments

Comments
 (0)