Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Reaction atom_map a list of atom maps #718

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
80 changes: 44 additions & 36 deletions arc/job/adapters/ts/gcn_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,46 +274,53 @@ def execute_gcn(self, exe_type: str = 'incore'):
charge=rxn.charge,
multiplicity=rxn.multiplicity,
)
write_sdf_files(rxn=rxn,
reactant_path=self.reactant_path,
product_path=self.product_path,
)
if exe_type == 'queue':
input_dict = {'reactant_path': self.reactant_path,
'product_path': self.product_path,
'local_path': self.local_path,
'yml_out_path': self.yml_out_path,
'repetitions': self.repetitions,
}
save_yaml_file(path=self.yml_in_path, content=input_dict)
self.legacy_queue_execution()
elif exe_type == 'incore':
for _ in range(self.repetitions):
run_subprocess_locally(direction='F',
reactant_path=self.reactant_path,
product_path=self.product_path,
ts_path=self.ts_fwd_path,
local_path=self.local_path,
ts_species=rxn.ts_species,
)
run_subprocess_locally(direction='R',
reactant_path=self.product_path,
product_path=self.reactant_path,
ts_path=self.ts_rev_path,
local_path=self.local_path,
ts_species=rxn.ts_species,
)
if len(self.reactions) < 5:
successes = len([tsg for tsg in rxn.ts_species.ts_guesses if tsg.success and 'gcn' in tsg.method])
if successes:
logger.info(f'GCN successfully found {successes} TS guesses for {rxn.label}.')
else:
logger.info(f'GCN did not find any successful TS guesses for {rxn.label}.')
for i in range(len(rxn.atom_maps)):
try:
write_sdf_files(rxn=rxn,
reactant_path=self.reactant_path,
product_path=self.product_path,
am_index=i,
)
except IndexError:
logger.warning(f'GCN adapter could not write SDF files for {rxn.label} with atom map {i}.')
continue
if exe_type == 'queue':
input_dict = {'reactant_path': self.reactant_path,
'product_path': self.product_path,
'local_path': self.local_path,
'yml_out_path': self.yml_out_path,
'repetitions': self.repetitions,
}
save_yaml_file(path=self.yml_in_path, content=input_dict)
self.legacy_queue_execution()
elif exe_type == 'incore':
for _ in range(self.repetitions):
run_subprocess_locally(direction='F',
reactant_path=self.reactant_path,
product_path=self.product_path,
ts_path=self.ts_fwd_path,
local_path=self.local_path,
ts_species=rxn.ts_species,
)
run_subprocess_locally(direction='R',
reactant_path=self.product_path,
product_path=self.reactant_path,
ts_path=self.ts_rev_path,
local_path=self.local_path,
ts_species=rxn.ts_species,
)
if len(self.reactions) < 5:
successes = len([tsg for tsg in rxn.ts_species.ts_guesses if tsg.success and 'gcn' in tsg.method])
if successes:
logger.info(f'GCN successfully found {successes} TS guesses for {rxn.label}.')
else:
logger.info(f'GCN did not find any successful TS guesses for {rxn.label}.')


def write_sdf_files(rxn: 'ARCReaction',
reactant_path: str,
product_path: str,
am_index: int = 0,
):
"""
Write reactant and product SDF files using RDKit.
Expand All @@ -322,9 +329,10 @@ def write_sdf_files(rxn: 'ARCReaction',
rxn (ARCReaction): The relevant reaction.
reactant_path (str): The path to the reactant SDF file.
product_path (str): The path to the product SDF file.
am_index (int, optional): The atom map index. Default: 0.
"""
reactant_rdkit_mol = rdkit_conf_from_mol(rxn.r_species[0].mol, rxn.r_species[0].get_xyz())[1]
mapped_product = rxn.get_single_mapped_product_xyz()
mapped_product = rxn.get_single_mapped_product_xyz(am_index)
product_rdkit_mol = rdkit_conf_from_mol(mapped_product.mol, mapped_product.get_xyz())[1]
w = Chem.SDWriter(reactant_path)
w.write(reactant_rdkit_mol)
Expand Down
54 changes: 27 additions & 27 deletions arc/job/adapters/ts/heuristics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,10 +302,10 @@ def test_heuristics_for_h_abstraction(self):
products=[Species(smiles='[CH3]'), Species(smiles='[H][H]')]))
rxn4.determine_family(rmg_database=self.rmgdb)
self.assertEqual(rxn4.family.label, 'H_Abstraction')
self.assertEqual(rxn4.atom_map[0], 0)
self.assertEqual(rxn4.atom_maps[0][0], 0)
for index in [1, 2, 3, 4]:
self.assertIn(rxn4.atom_map[index], [1, 2, 3, 4, 5])
self.assertIn(rxn4.atom_map[5], [4, 5])
self.assertIn(rxn4.atom_maps[0][index], [1, 2, 3, 4, 5])
self.assertIn(rxn4.atom_maps[0][5], [4, 5])
heuristics_4 = HeuristicsAdapter(job_type='tsg',
reactions=[rxn4],
testing=True,
Expand Down Expand Up @@ -926,13 +926,13 @@ def test_keeping_atom_order_in_ts(self):
p_species=[ARCSpecies(label='C2H5', smiles='C[CH2]', xyz=self.c2h5_xyz),
ARCSpecies(label='CCOOH', smiles='CCOO', xyz=self.ccooh_xyz)])
rxn_1.determine_family(rmg_database=self.rmgdb)
self.assertIn(rxn_1.atom_map[0], [0, 1])
self.assertIn(rxn_1.atom_map[1], [0, 1])
self.assertIn(rxn_1.atom_maps[0][0], [0, 1])
self.assertIn(rxn_1.atom_maps[0][1], [0, 1])
for index in [2, 3, 4, 5, 6, 7]:
self.assertIn(rxn_1.atom_map[index], [2, 3, 4, 5, 6, 16])
self.assertEqual(rxn_1.atom_map[8:12], [7, 8, 9, 10])
self.assertIn(tuple(rxn_1.atom_map[12:15]), itertools.permutations([13, 11, 12]))
self.assertIn(rxn_1.atom_map[15:], [[14, 15], [15, 14]])
self.assertIn(rxn_1.atom_maps[0][index], [2, 3, 4, 5, 6, 16])
self.assertEqual(rxn_1.atom_maps[0][8:12], [7, 8, 9, 10])
self.assertIn(tuple(rxn_1.atom_maps[0][12:15]), itertools.permutations([13, 11, 12]))
self.assertIn(rxn_1.atom_maps[0][15:], [[14, 15], [15, 14]])
heuristics_1 = HeuristicsAdapter(job_type='tsg',
reactions=[rxn_1],
testing=True,
Expand All @@ -952,12 +952,12 @@ def test_keeping_atom_order_in_ts(self):
ARCSpecies(label='C2H5', smiles='C[CH2]', xyz=self.c2h5_xyz)])
rxn_2.determine_family(rmg_database=self.rmgdb)
self.assertEqual(rxn_2.family.label, 'H_Abstraction')
self.assertEqual(rxn_2.atom_map[:2], [11, 10])
self.assertIn(tuple(rxn_2.atom_map[2:5]), itertools.permutations([9, 16, 15]))
self.assertIn(tuple(rxn_2.atom_map[5:8]), itertools.permutations([12, 13, 14]))
self.assertEqual(rxn_2.atom_map[8:12], [0, 1, 2, 3])
self.assertIn(tuple(rxn_2.atom_map[12:15]), itertools.permutations([4, 5, 6]))
self.assertIn(tuple(rxn_2.atom_map[15:]), itertools.permutations([7, 8]))
self.assertEqual(rxn_2.atom_maps[0][:2], [11, 10])
self.assertIn(tuple(rxn_2.atom_maps[0][2:5]), itertools.permutations([9, 16, 15]))
self.assertIn(tuple(rxn_2.atom_maps[0][5:8]), itertools.permutations([12, 13, 14]))
self.assertEqual(rxn_2.atom_maps[0][8:12], [0, 1, 2, 3])
self.assertIn(tuple(rxn_2.atom_maps[0][12:15]), itertools.permutations([4, 5, 6]))
self.assertIn(tuple(rxn_2.atom_maps[0][15:]), itertools.permutations([7, 8]))
heuristics_2 = HeuristicsAdapter(job_type='tsg',
reactions=[rxn_2],
testing=True,
Expand All @@ -976,12 +976,12 @@ def test_keeping_atom_order_in_ts(self):
p_species=[ARCSpecies(label='C2H5', smiles='C[CH2]', xyz=self.c2h5_xyz),
ARCSpecies(label='CCOOH', smiles='CCOO', xyz=self.ccooh_xyz)])
rxn_3.determine_family(rmg_database=self.rmgdb)
self.assertEqual(rxn_3.atom_map[:4], [7, 8, 9, 10])
self.assertIn(tuple(rxn_3.atom_map[4:7]), itertools.permutations([11, 12, 13]))
self.assertIn(tuple(rxn_3.atom_map[7:9]), itertools.permutations([14, 15]))
self.assertEqual(rxn_3.atom_map[9:11], [1, 0])
self.assertIn(tuple(rxn_3.atom_map[11:14]), itertools.permutations([16, 5, 6]))
self.assertIn(tuple(rxn_3.atom_map[14:]), itertools.permutations([3, 4, 2]))
self.assertEqual(rxn_3.atom_maps[0][:4], [7, 8, 9, 10])
self.assertIn(tuple(rxn_3.atom_maps[0][4:7]), itertools.permutations([11, 12, 13]))
self.assertIn(tuple(rxn_3.atom_maps[0][7:9]), itertools.permutations([14, 15]))
self.assertEqual(rxn_3.atom_maps[0][9:11], [1, 0])
self.assertIn(tuple(rxn_3.atom_maps[0][11:14]), itertools.permutations([16, 5, 6]))
self.assertIn(tuple(rxn_3.atom_maps[0][14:]), itertools.permutations([3, 4, 2]))

heuristics_3 = HeuristicsAdapter(job_type='tsg',
reactions=[rxn_3],
Expand All @@ -1001,12 +1001,12 @@ def test_keeping_atom_order_in_ts(self):
p_species=[ARCSpecies(label='CCOOH', smiles='CCOO', xyz=self.ccooh_xyz),
ARCSpecies(label='C2H5', smiles='C[CH2]', xyz=self.c2h5_xyz)])
rxn_4.determine_family(rmg_database=self.rmgdb)
self.assertEqual(rxn_4.atom_map[:4], [0, 1, 2, 3])
self.assertIn(tuple(rxn_4.atom_map[4:7]), itertools.permutations([4, 5, 6]))
self.assertIn(tuple(rxn_4.atom_map[7:9]), itertools.permutations([7, 8]))
self.assertEqual(rxn_4.atom_map[9:11], [11, 10])
self.assertIn(tuple(rxn_4.atom_map[11:14]), itertools.permutations([9, 15, 16]))
self.assertIn(tuple(rxn_4.atom_map[14:]), itertools.permutations([12, 13, 14 ]))
self.assertEqual(rxn_4.atom_maps[0][:4], [0, 1, 2, 3])
self.assertIn(tuple(rxn_4.atom_maps[0][4:7]), itertools.permutations([4, 5, 6]))
self.assertIn(tuple(rxn_4.atom_maps[0][7:9]), itertools.permutations([7, 8]))
self.assertEqual(rxn_4.atom_maps[0][9:11], [11, 10])
self.assertIn(tuple(rxn_4.atom_maps[0][11:14]), itertools.permutations([9, 15, 16]))
self.assertIn(tuple(rxn_4.atom_maps[0][14:]), itertools.permutations([12, 13, 14 ]))
heuristics_4 = HeuristicsAdapter(job_type='tsg',
reactions=[rxn_4],
testing=True,
Expand Down
Loading
Loading