diff --git a/arc/mapping/engine_test.py b/arc/mapping/engine_test.py index 29beb4019f..028f8074a8 100644 --- a/arc/mapping/engine_test.py +++ b/arc/mapping/engine_test.py @@ -601,15 +601,18 @@ def test_label_species_atoms(self): self.assertEqual(atom.label,str(index)) index +=1 - def test_cut_species_for_mapping(self): + def test_cut_species_based_on_atom_indices(self): """test the cut_species_for_mapping function""" rxn_1_test = ARCReaction(r_species=[self.r_1, self.r_2], p_species=[self.p_1, self.p_2]) rxn_1_test.determine_family(self.db) - reactants, products, loc_r, loc_p = prepare_reactants_and_products_for_scissors(rxn_1_test, - self.r_label_dict_rxn_1, - self.p_label_dict_rxn_1) - r_cuts = cut_species_for_mapping(reactants, loc_r) - p_cuts = cut_species_for_mapping(products, loc_p) + reactants, products = copy_species_list_for_mapping(rxn_1_test.r_species), copy_species_list_for_mapping(rxn_1_test.p_species) + label_species_atoms(reactants), label_species_atoms(products) + + r_bdes, p_bdes = find_all_bdes(rxn_1_test, self.r_label_dict_rxn_1, True), find_all_bdes(rxn_1_test, self.p_label_dict_rxn_1, False) + + r_cuts = cut_species_based_on_atom_indices(reactants, r_bdes) + p_cuts = cut_species_based_on_atom_indices(products, p_bdes) + self.assertIn("C[CH]C", [r_cut.mol.copy(deep=True).smiles for r_cut in r_cuts]) self.assertIn("[F]", [r_cut.mol.copy(deep=True).smiles for r_cut in r_cuts]) @@ -621,21 +624,24 @@ def test_cut_species_for_mapping(self): spc = ARCSpecies(label="test", smiles="CNC", bdes = [(1, 2), (2, 3)]) for i, a in enumerate(spc.mol.atoms): a.label=str(i) - cuts = cut_species_for_mapping([spc], [2]) + cuts = cut_species_based_on_atom_indices([spc], [(1, 2), (2, 3)]) self.assertEqual(len(cuts), 3) for cut in cuts: self.assertTrue(any([cut.mol.copy(deep=True).is_isomorphic(ARCSpecies(label="1", smiles="[CH3]").mol), cut.mol.copy(deep=True).is_isomorphic(ARCSpecies(label="2", smiles="[NH]").mol)])) - cuts = cut_species_for_mapping([ARCSpecies(label="H2", smiles="[H][H]", bdes=[(1, 2)])], [1]) + h2 = ARCSpecies(label="H2", smiles="[H][H]") + label_species_atoms([h2]) + + cuts = cut_species_based_on_atom_indices([h2], [(1, 2)]) self.assertEqual(len(cuts), 2) for cut in cuts: self.assertEqual(cut.get_xyz()["symbols"], ('H',)) - - def test_multiple_cut_on_species(self): - """test the multiple_cut_on_species function""" - spc = ARCSpecies(label="test", smiles="NCN", bdes = [(1, 2), (2, 3)]) - for i, a in enumerate(spc.mol.atoms): + + spcs = [ARCSpecies(label="r", smiles = 'O=C(O)CCF')] + label_species_atoms(spcs) + cuts = cut_species_based_on_atom_indices(spcs, [(6, 5), (4, 2), (3, 7)]) + self.assertEqual(len(cuts), 4) a.label=str(i) spc.final_xyz = spc.get_xyz() cuts = multiple_cut_on_species(spc, spc.bdes)