From 823d3a2fbf28bae7c864773e91540409852912a9 Mon Sep 17 00:00:00 2001 From: Alon Grinberg Dana Date: Sun, 15 Dec 2024 09:39:51 +0200 Subject: [PATCH] f engine mods --- arc/mapping/engine.py | 109 +++++++++++++++++++++++++++++------------- 1 file changed, 77 insertions(+), 32 deletions(-) diff --git a/arc/mapping/engine.py b/arc/mapping/engine.py index 1ded4c0cb9..1d8e706728 100644 --- a/arc/mapping/engine.py +++ b/arc/mapping/engine.py @@ -19,6 +19,7 @@ from arc.common import convert_list_index_0_to_1, extremum_list, generate_resonance_structures, logger, key_by_val from arc.exceptions import SpeciesError +from arc.family import ReactionFamily from arc.species import ARCSpecies from arc.species.conformers import determine_chirality from arc.species.converter import compare_confs, sort_xyz_using_indices, translate_xyz, xyz_from_data, xyz_to_str @@ -874,27 +875,6 @@ def make_bond_changes(rxn: 'ARCReaction', r_cut.mol.update() -def assign_labels_to_products(rxn: 'ARCReaction', - p_label_dict: dict): - """ - Add the indices to the reactants and products. - - Args: - rxn: ARCReaction object to be mapped - p_label_dict: the labels of the products - Consider changing in rmgpy. - - Returns: - Adding labels to the atoms of the reactants and products, to be identified later. - """ - atom_index = 0 - for product_ in rxn.p_species: - for atom in product_.mol.atoms: - if atom_index in p_label_dict.values() and (atom.label is str or atom.label is None): - atom.label = key_by_val(p_label_dict,atom_index) - atom_index+=1 - - def update_xyz(species: List[ARCSpecies]) -> List[ARCSpecies]: """ A helper function, updates the xyz values of each species after cutting. This is important, since the @@ -1037,6 +1017,7 @@ def determine_bdes_on_spc_based_on_atom_labels(spc: "ARCSpecies", bde: Tuple[int def cut_species_based_on_atom_indices(species: List["ARCSpecies"], bdes: List[Tuple[int, int]], + ref_species: Optional[List["ARCSpecies"]] = None, ) -> Optional[List["ARCSpecies"]]: """ A function for scissoring species based on their atom indices. @@ -1045,10 +1026,13 @@ def cut_species_based_on_atom_indices(species: List["ARCSpecies"], species (List[ARCSpecies]): The species list that requires scission. bdes (List[Tuple[int, int]]): A list of the atoms between which the bond should be scissored. The atoms are described using the atom labels, and not the actual atom positions. + ref_species (Optional[List[ARCSpecies]]): A reference species list for which BDE indices are given. Returns: Optional[List["ARCSpecies"]]: The species list input after the scission. """ + if ref_species is not None: + bdes = translate_indices_based_on_ref_species(species, ref_species, bdes) if not bdes: return species for bde in bdes: @@ -1074,6 +1058,62 @@ def cut_species_based_on_atom_indices(species: List["ARCSpecies"], return species +def translate_indices_based_on_ref_species(species: List["ARCSpecies"], + ref_species: List["ARCSpecies"], + bdes: List[Tuple[int, int]], + ) -> Optional[List[Tuple[int, int]]]: + """ + A function for translating the atom indices based on a reference species list. + The given bde indices refer to ``ref_species``, and they'll be translated to refer to ``species``. + + Args: + species (List[ARCSpecies]): The species list for which the indices should be translated. + ref_species (List[ARCSpecies]): The reference species list. + bdes (List[Tuple[int, int]]): The BDE indices to be translated. + + Returns: + Optional[List[Tuple[int, int]]]: The translated BDE indices. + """ + visited_ref_species = list() + species_map = dict() # maps ref species j to species i + index_map = dict() # keys are ref species j indices, values are atom maps between ref species j and species i + for i, spc in enumerate(species): + for j, ref_spc in enumerate(ref_species): + if j not in visited_ref_species and spc.is_isomorphic(ref_spc): + visited_ref_species.append(j) + species_map[j] = i + index_map[j] = map_two_species(ref_spc, spc) + break + new_bdes = list() + ref_spcs_lengths = [ref_spc.number_of_atoms for ref_spc in ref_species] + accum_sum_ref_spcs_lengths = [sum(ref_spcs_lengths[:i+1]) for i in range(len(ref_spcs_lengths))] + spcs_lengths = [spc.number_of_atoms for spc in species] + accum_sum_spcs_lengths = [sum(spcs_lengths[:i+1]) for i in range(len(spcs_lengths))] + for bde in bdes: + a, b = bde + translated_bde = list() + for n in [a, b]: + found = False + for j, ref_len in enumerate(accum_sum_ref_spcs_lengths): + if n < ref_len: + atom_map = index_map[j] + i = species_map[j] + if atom_map is None or i is None: + return None + increment = accum_sum_spcs_lengths[i - 1] if i > 0 else 0 + translated_atom = atom_map[n - accum_sum_ref_spcs_lengths[j]] + increment + translated_bde.append(translated_atom) + found = True + break + if not found: + return None + if len(translated_bde) == 2: + new_bdes.append(tuple(translated_bde)) + else: + return None + return new_bdes + + def copy_species_list_for_mapping(species: List["ARCSpecies"]) -> List["ARCSpecies"]: """ A helper function for copying the species list for mapping. Also keeps the atom indices when copying. @@ -1089,21 +1129,26 @@ def copy_species_list_for_mapping(species: List["ARCSpecies"]) -> List["ARCSpeci return copies -def find_all_bdes(rxn: "ARCReaction", label_dict: dict, is_reactants: bool) -> List[Tuple[int, int]]: +def find_all_breaking_bonds(rxn: "ARCReaction", + label_dict: Dict[str, int], + r_direction: bool, + ) -> Optional[List[Tuple[int, int]]]: """ - A function for finding all the broken(/formed) bonds during a chemical reaction, based on the atom indices. + A function for finding all the broken (or formed of the direction to consider starts with the products) + bonds during a chemical reaction, based on marked atom labels. Args: rxn (ARCReaction): The reaction in question. - label_dict (dict): A dictionary of the atom indices to the atom labels. - is_reactants (bool): Whether the species list represents reactants or products. + label_dict (Dict[str, int]): Keys are atom labels (e.g., '*1'), values are atom indices (0-indexed). + r_direction (bool): Whether to consider the reactants direction (``True``) or the products direction (``False``). Returns: - List[Tuple[int, int]]: A list of tuples of the form (atom_index1, atom_index2) for each broken bond. - Note that these represent the atom indices to be cut, and not final BDEs. + List[Tuple[int, int]]: Entries are tuples of the form (atom_index1, atom_index2) for each broken bond (1-indexed), + representing the atom indices to be cut. """ - bdes = list() - for action in rxn.family.forward_recipe.actions: - if action[0].lower() == ("break_bond" if is_reactants else "form_bond"): - bdes.append((label_dict[action[1]] + 1, label_dict[action[3]] + 1)) - return bdes + family = ReactionFamily(label=rxn.family) + breaking_bonds = list() + for action in family.actions: + if action[0].lower() == ("break_bond" if r_direction else "form_bond"): + breaking_bonds.append((label_dict[action[1]], label_dict[action[3]])) + return breaking_bonds