Skip to content

Commit

Permalink
f engine mods
Browse files Browse the repository at this point in the history
  • Loading branch information
alongd committed Dec 15, 2024
1 parent 3e7f071 commit 823d3a2
Showing 1 changed file with 77 additions and 32 deletions.
109 changes: 77 additions & 32 deletions arc/mapping/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from arc.common import convert_list_index_0_to_1, extremum_list, generate_resonance_structures, logger, key_by_val
from arc.exceptions import SpeciesError
from arc.family import ReactionFamily
from arc.species import ARCSpecies
from arc.species.conformers import determine_chirality
from arc.species.converter import compare_confs, sort_xyz_using_indices, translate_xyz, xyz_from_data, xyz_to_str
Expand Down Expand Up @@ -874,27 +875,6 @@ def make_bond_changes(rxn: 'ARCReaction',
r_cut.mol.update()


def assign_labels_to_products(rxn: 'ARCReaction',
p_label_dict: dict):
"""
Add the indices to the reactants and products.
Args:
rxn: ARCReaction object to be mapped
p_label_dict: the labels of the products
Consider changing in rmgpy.
Returns:
Adding labels to the atoms of the reactants and products, to be identified later.
"""
atom_index = 0
for product_ in rxn.p_species:
for atom in product_.mol.atoms:
if atom_index in p_label_dict.values() and (atom.label is str or atom.label is None):
atom.label = key_by_val(p_label_dict,atom_index)
atom_index+=1


def update_xyz(species: List[ARCSpecies]) -> List[ARCSpecies]:
"""
A helper function, updates the xyz values of each species after cutting. This is important, since the
Expand Down Expand Up @@ -1037,6 +1017,7 @@ def determine_bdes_on_spc_based_on_atom_labels(spc: "ARCSpecies", bde: Tuple[int

def cut_species_based_on_atom_indices(species: List["ARCSpecies"],
bdes: List[Tuple[int, int]],
ref_species: Optional[List["ARCSpecies"]] = None,
) -> Optional[List["ARCSpecies"]]:
"""
A function for scissoring species based on their atom indices.
Expand All @@ -1045,10 +1026,13 @@ def cut_species_based_on_atom_indices(species: List["ARCSpecies"],
species (List[ARCSpecies]): The species list that requires scission.
bdes (List[Tuple[int, int]]): A list of the atoms between which the bond should be scissored.
The atoms are described using the atom labels, and not the actual atom positions.
ref_species (Optional[List[ARCSpecies]]): A reference species list for which BDE indices are given.
Returns:
Optional[List["ARCSpecies"]]: The species list input after the scission.
"""
if ref_species is not None:
bdes = translate_indices_based_on_ref_species(species, ref_species, bdes)
if not bdes:
return species
for bde in bdes:
Expand All @@ -1074,6 +1058,62 @@ def cut_species_based_on_atom_indices(species: List["ARCSpecies"],
return species


def translate_indices_based_on_ref_species(species: List["ARCSpecies"],
ref_species: List["ARCSpecies"],
bdes: List[Tuple[int, int]],
) -> Optional[List[Tuple[int, int]]]:
"""
A function for translating the atom indices based on a reference species list.
The given bde indices refer to ``ref_species``, and they'll be translated to refer to ``species``.
Args:
species (List[ARCSpecies]): The species list for which the indices should be translated.
ref_species (List[ARCSpecies]): The reference species list.
bdes (List[Tuple[int, int]]): The BDE indices to be translated.
Returns:
Optional[List[Tuple[int, int]]]: The translated BDE indices.
"""
visited_ref_species = list()
species_map = dict() # maps ref species j to species i
index_map = dict() # keys are ref species j indices, values are atom maps between ref species j and species i
for i, spc in enumerate(species):
for j, ref_spc in enumerate(ref_species):
if j not in visited_ref_species and spc.is_isomorphic(ref_spc):
visited_ref_species.append(j)
species_map[j] = i
index_map[j] = map_two_species(ref_spc, spc)
break
new_bdes = list()
ref_spcs_lengths = [ref_spc.number_of_atoms for ref_spc in ref_species]
accum_sum_ref_spcs_lengths = [sum(ref_spcs_lengths[:i+1]) for i in range(len(ref_spcs_lengths))]
spcs_lengths = [spc.number_of_atoms for spc in species]
accum_sum_spcs_lengths = [sum(spcs_lengths[:i+1]) for i in range(len(spcs_lengths))]
for bde in bdes:
a, b = bde
translated_bde = list()
for n in [a, b]:
found = False
for j, ref_len in enumerate(accum_sum_ref_spcs_lengths):
if n < ref_len:
atom_map = index_map[j]
i = species_map[j]
if atom_map is None or i is None:
return None
increment = accum_sum_spcs_lengths[i - 1] if i > 0 else 0
translated_atom = atom_map[n - accum_sum_ref_spcs_lengths[j]] + increment
translated_bde.append(translated_atom)
found = True
break
if not found:
return None
if len(translated_bde) == 2:
new_bdes.append(tuple(translated_bde))
else:
return None
return new_bdes


def copy_species_list_for_mapping(species: List["ARCSpecies"]) -> List["ARCSpecies"]:
"""
A helper function for copying the species list for mapping. Also keeps the atom indices when copying.
Expand All @@ -1089,21 +1129,26 @@ def copy_species_list_for_mapping(species: List["ARCSpecies"]) -> List["ARCSpeci
return copies


def find_all_bdes(rxn: "ARCReaction", label_dict: dict, is_reactants: bool) -> List[Tuple[int, int]]:
def find_all_breaking_bonds(rxn: "ARCReaction",
label_dict: Dict[str, int],
r_direction: bool,
) -> Optional[List[Tuple[int, int]]]:
"""
A function for finding all the broken(/formed) bonds during a chemical reaction, based on the atom indices.
A function for finding all the broken (or formed of the direction to consider starts with the products)
bonds during a chemical reaction, based on marked atom labels.
Args:
rxn (ARCReaction): The reaction in question.
label_dict (dict): A dictionary of the atom indices to the atom labels.
is_reactants (bool): Whether the species list represents reactants or products.
label_dict (Dict[str, int]): Keys are atom labels (e.g., '*1'), values are atom indices (0-indexed).
r_direction (bool): Whether to consider the reactants direction (``True``) or the products direction (``False``).
Returns:
List[Tuple[int, int]]: A list of tuples of the form (atom_index1, atom_index2) for each broken bond.
Note that these represent the atom indices to be cut, and not final BDEs.
List[Tuple[int, int]]: Entries are tuples of the form (atom_index1, atom_index2) for each broken bond (1-indexed),
representing the atom indices to be cut.
"""
bdes = list()
for action in rxn.family.forward_recipe.actions:
if action[0].lower() == ("break_bond" if is_reactants else "form_bond"):
bdes.append((label_dict[action[1]] + 1, label_dict[action[3]] + 1))
return bdes
family = ReactionFamily(label=rxn.family)
breaking_bonds = list()
for action in family.actions:
if action[0].lower() == ("break_bond" if r_direction else "form_bond"):
breaking_bonds.append((label_dict[action[1]], label_dict[action[3]]))
return breaking_bonds

0 comments on commit 823d3a2

Please sign in to comment.