Skip to content

Commit

Permalink
TMP
Browse files Browse the repository at this point in the history
  • Loading branch information
alongd committed Dec 20, 2024
1 parent 0329d24 commit f112282
Show file tree
Hide file tree
Showing 17 changed files with 535 additions and 416 deletions.
1 change: 0 additions & 1 deletion arc/checks/ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
)
from arc.imports import settings
from arc.species.converter import check_xyz_dict, displace_xyz, xyz_to_dmat
from arc.mapping.engine import get_atom_indices_of_labeled_atoms_in_an_rmg_reaction
from arc.statmech.factory import statmech_factory

if TYPE_CHECKING:
Expand Down
48 changes: 0 additions & 48 deletions arc/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,13 @@
import warnings
import yaml
from collections import deque
from itertools import chain
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import qcelemental as qcel

from arkane.ess import ess_factory, GaussianLog, MolproLog, OrcaLog, QChemLog, TeraChemLog
import rmgpy
from rmgpy.exceptions import AtomTypeError, ILPSolutionError, ResonanceError
from rmgpy.molecule.atomtype import ATOMTYPES
from rmgpy.molecule.element import get_element
Expand All @@ -40,9 +38,7 @@


if TYPE_CHECKING:
from rmgpy.reaction import Reaction
from rmgpy.species import Species
from arc.reaction import ARCReaction


logger = logging.getLogger('arc')
Expand Down Expand Up @@ -1563,50 +1559,6 @@ def calc_rmsd(x: Union[list, np.array],
return float(rmsd)


def _check_r_n_p_symbols_between_rmg_and_arc_rxns(arc_reaction: 'ARCReaction',
rmg_reactions: List['Reaction'],
) -> bool:
"""
A helper function to check that atom symbols are in the correct order between an ARC reaction
and its corresponding RMG reactions generated by the get_rmg_reactions_from_arc_reaction() function.
Used internally for testing.
Args:
arc_reaction (ARCReaction): The ARCReaction object to inspect.
rmg_reactions (List['Reaction']): Entries are RMG Reaction objects to inspect.
Could contain either Species or Molecule object as reactants/products.
Returns:
bool: Whether atom symbols are in the same respective order.
"""
result = True
num_rs, num_ps = len(arc_reaction.r_species), len(arc_reaction.p_species)
arc_r_symbols = [atom.element.symbol for atom in chain(*tuple(arc_reaction.r_species[i].mol.atoms for i in range(num_rs)))]
arc_p_symbols = [atom.element.symbol for atom in chain(*tuple(arc_reaction.p_species[i].mol.atoms for i in range(num_ps)))]
for rmg_reaction in rmg_reactions:
rmg_r_symbols = [atom.element.symbol
for atom in chain(*tuple(rmg_reaction.reactants[i].atoms
if isinstance(rmg_reaction.reactants[i], Molecule)
else rmg_reaction.reactants[i].molecule[0].atoms
for i in range(num_rs)))]
rmg_p_symbols = [atom.element.symbol
for atom in chain(*tuple(rmg_reaction.products[i].atoms
if isinstance(rmg_reaction.products[i], Molecule)
else rmg_reaction.products[i].molecule[0].atoms
for i in range(num_ps)))]
if any(symbol_1 != symbol_2 for symbol_1, symbol_2 in zip(arc_r_symbols, rmg_r_symbols)):
print('\nDifferent element order in reactants between ARC and RMG:') # Don't modify to logging.
print(arc_r_symbols)
print(rmg_r_symbols)
result = False
if any(symbol_1 != symbol_2 for symbol_1, symbol_2 in zip(arc_p_symbols, rmg_p_symbols)):
print('\nDifferent element order in products between ARC and RMG:')
print(arc_p_symbols)
print(rmg_p_symbols)
result = False
return result


def safe_copy_file(source: str,
destination: str,
wait: int = 10,
Expand Down
9 changes: 0 additions & 9 deletions arc/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@
import arc.common as common
from arc.exceptions import InputError, SettingsError
from arc.imports import settings
from arc.mapping.engine import get_rmg_reactions_from_arc_reaction
import arc.species.converter as converter
from arc.reaction import ARCReaction
from arc.species.species import ARCSpecies


Expand Down Expand Up @@ -1268,13 +1266,6 @@ def test_sort_atoms_in_descending_label_order(self):
mol.atoms[0].label = "a"
self.assertIsNone(common.sort_atoms_in_descending_label_order(mol=mol))

def test_check_r_n_p_symbols_between_rmg_and_arc_rxns(self):
"""Test the _check_r_n_p_symbols_between_rmg_and_arc_rxns() function"""
arc_rxn = ARCReaction(r_species=[ARCSpecies(label='CH4', smiles='C'), ARCSpecies(label='OH', smiles='[OH]')],
p_species=[ARCSpecies(label='CH3', smiles='[CH3]'), ARCSpecies(label='H2O', smiles='O')])
rmg_reactions = get_rmg_reactions_from_arc_reaction(arc_reaction=arc_rxn)
self.assertTrue(common._check_r_n_p_symbols_between_rmg_and_arc_rxns(arc_rxn, rmg_reactions))

def test_almost_equal_coords(self):
"""Test the almost_equal_coords() function"""
with self.assertRaises(TypeError):
Expand Down
2 changes: 2 additions & 0 deletions arc/family/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
import arc.family.family

Check notice

Code scanning / CodeQL

Module is imported with 'import' and 'import from' Note

Module 'arc.family.family' is imported with both 'import' and 'import from'.
from arc.family.family import ReactionFamily
from arc.family.family import get_reaction_family_products
14 changes: 11 additions & 3 deletions arc/family/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from arc.common import clean_text, generate_resonance_structures, get_logger
from arc.imports import settings
from arc.species.converter import check_isomorphism

if TYPE_CHECKING:
from arc.species import ARCSpecies
Expand Down Expand Up @@ -404,6 +403,7 @@ def get_reaction_family_products(rxn: 'ARCReaction',
family_labels = get_all_families(rmg_family_set=rmg_family_set,
consider_rmg_families=consider_rmg_families,
consider_arc_families=consider_arc_families)
print(f'\n\n\n\n\nfamily_labels: {family_labels}\n\n\n\n')
product_dicts = list()
for family_label in family_labels:
# Forward:
Expand Down Expand Up @@ -462,6 +462,9 @@ def determine_possible_reaction_products_from_family(rxn: 'ARCReaction',
List[dict]: A list of dictionaries, each containing the family label, the group labels, the products,
and whether the family's template also represents its own reverse.
"""
if family_label == '1,2_NH3_elimination':
print('rxn: ', rxn)
print('debug')
product_dicts = list()
family = ReactionFamily(label=family_label, consider_arc_families=consider_arc_families)
products = family.generate_products(reactants=rxn.get_reactants_and_products(arc=True, return_copies=True)[0])
Expand All @@ -478,6 +481,8 @@ def determine_possible_reaction_products_from_family(rxn: 'ARCReaction',
'own_reverse': family.own_reverse,
'discovered_in_reverse': reverse,
})
if family_label == '1,2_NH3_elimination':
print(f'product_dicts: {product_dicts}')
return product_dicts


Expand Down Expand Up @@ -519,7 +524,7 @@ def check_product_isomorphism(products: List['Molecule'],
Returns:
bool: Whether the products are isomorphic to the species.
"""
prods_a = [generate_resonance_structures(mol) or [mol] for mol in products]
prods_a = [generate_resonance_structures(mol.copy(deep=True)) or [mol.copy(deep=True)] for mol in products]
prods_b = [spc.mol_list or [spc.mol] for spc in p_species]
if len(prods_a) == 1:
prod_a = prods_a[0]
Expand Down Expand Up @@ -564,6 +569,7 @@ def get_all_families(rmg_family_set: Union[List[str], str] = 'default',
) -> List[str]:
"""
Get all available RMG and ARC families.
If ``rmg_family_set`` is a list of family labels and does not contain family sets, it will be returned as is.
Args:
rmg_family_set (Union[List[str], str], optional): The RMG family set to use.
Expand All @@ -574,9 +580,11 @@ def get_all_families(rmg_family_set: Union[List[str], str] = 'default',
List[str]: A list of all available families.
"""
rmg_family_set = rmg_family_set or 'default'
family_sets = get_rmg_recommended_family_sets()
if isinstance(rmg_family_set, list) and all(fam not in family_sets for fam in rmg_family_set):
return rmg_family_set
rmg_families, arc_families = list(), list()
if consider_rmg_families:
family_sets = get_rmg_recommended_family_sets()
if not isinstance(rmg_families, list) and rmg_family_set not in list(family_sets) + ['all']:
raise ValueError(f'Invalid RMG family set: {rmg_family_set}')
if rmg_family_set == 'all':
Expand Down
15 changes: 15 additions & 0 deletions arc/family/family_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,8 @@ def test_get_all_families(self):
families = get_all_families(consider_rmg_families=False)
self.assertIsInstance(families, list)
self.assertIn('hydrolysis', families)
families = get_all_families(rmg_family_set=['H_Abstraction'])
self.assertEqual(families, ['H_Abstraction'])

def test_get_rmg_recommended_family_sets(self):
"""Test getting RMG recommended family sets"""
Expand Down Expand Up @@ -901,13 +903,15 @@ def test_get_entries(self):
"""Test getting entries from a family"""
fam_1 = ReactionFamily('1,3_Insertion_ROR')
groups_as_lines = fam_1.get_groups_file_as_lines()
print(f'groups_as_lines:\n{groups_as_lines}')
entries = get_entries(groups_as_lines=groups_as_lines, entry_labels=['doublebond', 'cco_2H'])
self.assertEqual(entries, {'doublebond': 'OR{Cd_Cdd, Cdd_Cd, Cd_Cd, Sd_Cd, N1dc_N5ddc, N3d_Cd}',
'cco_2H': """1 *1 Cd u0 {2,D} {3,S} {4,S}
2 *2 Cdd u0 {1,D} {5,D}
3 H u0 {1,S}
4 H u0 {1,S}
5 [O2d,S2d] u0 {2,D}"""})
raise

def test_get_isomorphic_subgraph(self):
"""Test getting the isomorphic subgraph"""
Expand Down Expand Up @@ -960,6 +964,17 @@ def test_get_isomorphic_subgraph(self):
)
self.assertEqual(isomorphic_subgraph, {0: '*3', 4: '*1', 7: '*2'})

# def test_order_species_list(self):
# """Test the order_species_list() function"""
# spc1 = ARCSpecies(label='spc1', smiles='C')
# spc2 = ARCSpecies(label='spc2', smiles='CC')
# ordered_species_list = order_species_list(species_list=[spc2, spc1], reference_species=[spc1, spc2])
# self.assertEqual(ordered_species_list, [spc1, spc2])
# ordered_species_list = order_species_list(species_list=[spc2, spc1], reference_species=[spc2, spc1])
# self.assertEqual(ordered_species_list, [spc2, spc1])
# ordered_species_list = order_species_list(species_list=[spc2.mol, spc1], reference_species=[spc2, spc1.mol])
# self.assertEqual(ordered_species_list, [spc2.mol, spc1])

Check notice

Code scanning / CodeQL

Commented-out code Note

This comment appears to contain commented-out code.


if __name__ == '__main__':
unittest.main(testRunner=unittest.TextTestRunner(verbosity=2))
Loading

0 comments on commit f112282

Please sign in to comment.