Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/nmd' into debug_statement
Browse files Browse the repository at this point in the history
  • Loading branch information
calvinp0 committed Dec 3, 2024
2 parents 7b8e561 + 0f27b38 commit f5b373d
Show file tree
Hide file tree
Showing 38 changed files with 13,555 additions and 90,778 deletions.
493 changes: 493 additions & 0 deletions arc/checks/nmd.py

Large diffs are not rendered by default.

779 changes: 779 additions & 0 deletions arc/checks/nmd_test.py

Large diffs are not rendered by default.

108 changes: 17 additions & 91 deletions arc/checks/ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import numpy as np
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import arc.rmgdb as rmgdb
from arc import parser
from arc.checks.nmd import analyze_ts_normal_mode_displacement
from arc.common import (ARC_PATH,
convert_list_index_0_to_1,
extremum_list,
Expand All @@ -19,10 +19,7 @@
sum_list_entries,
)
from arc.imports import settings
from arc.species.converter import check_xyz_dict, displace_xyz, xyz_to_dmat
from arc.mapping.engine import (get_atom_indices_of_labeled_atoms_in_an_rmg_reaction,
get_rmg_reactions_from_arc_reaction,
)
from arc.species.converter import check_xyz_dict, xyz_to_dmat
from arc.statmech.factory import statmech_factory

if TYPE_CHECKING:
Expand Down Expand Up @@ -74,9 +71,8 @@ def check_ts(reaction: 'ARCReaction',
"""
checks = checks or list()
for entry in checks:
if entry not in ['energy', 'freq', 'IRC', 'rotors']:
raise ValueError(f"Requested checks could be 'energy', 'freq', 'IRC', or 'rotors', got:\n{checks}")

if entry not in ['energy', 'NMD', 'IRC', 'rotors']:
raise ValueError(f"Requested checks could be 'energy', 'IRC', 'NMD', or 'rotors', got:\n{checks}")
if 'energy' in checks:
if not reaction.ts_species.ts_checks['E0']:
rxn_copy = compute_rxn_e0(reaction=reaction,
Expand All @@ -91,26 +87,17 @@ def check_ts(reaction: 'ARCReaction',
if reaction.ts_species.ts_checks['E0'] is None and not reaction.ts_species.ts_checks['e_elect']:
check_rxn_e_elect(reaction=reaction, verbose=verbose)

if 'freq' in checks or (not reaction.ts_species.ts_checks.get('NMD', False) and job is not None):
# Check if the job adapter is 'autotst' to decide whether to skip NMD
if 'NMD' in checks and not reaction.ts_species.ts_checks['NMD']:
if job.species[0].chosen_ts_method == "autotst":
logger.info(
f'Skipping normal mode displacement check for TS {reaction.ts_species.label} '
f'due to job adapter "autotst".'
)
reaction.ts_species.ts_checks['NMD'] = True
else:
try:
check_normal_mode_displacement(reaction, job=job)
except (ValueError, KeyError) as e:
logger.error(f'Could not check normal mode displacement, got: \n{e}')
reaction.ts_species.ts_checks['NMD'] = True

# Handle skipping NMD based on the `skip_nmd` flag
if skip_nmd and not reaction.ts_species.ts_checks.get('NMD', False):
logger.warning(
f'Skipping normal mode displacement check for TS {reaction.ts_species.label}.'
)
check_normal_mode_displacement(reaction, job=job)
if skip_nmd and not reaction.ts_species.ts_checks['NMD']:
logger.warning(f'Skipping failed normal mode displacement check for TS {reaction.ts_species.label}')
reaction.ts_species.ts_checks['NMD'] = True

if 'rotors' in checks or (ts_passed_checks(species=reaction.ts_species, exemptions=['E0', 'warnings', 'IRC'])
Expand Down Expand Up @@ -301,7 +288,7 @@ def report_ts_and_wells_energy(r_e: float,

def check_normal_mode_displacement(reaction: 'ARCReaction',
job: Optional['JobAdapter'],
amplitudes: Optional[Union[float, List[float]]] = None,
amplitude: Optional[Union[float, list]] = None,
):
"""
Check the normal mode displacement by identifying bonds that break and form
Expand All @@ -310,75 +297,14 @@ def check_normal_mode_displacement(reaction: 'ARCReaction',
Args:
reaction (ARCReaction): The reaction for which the TS is checked.
job (JobAdapter): The frequency job object instance.
amplitudes (Union[float, List[float]], optional): The factor(s) multiplication for the displacement.
"""
if job is None:
return
if reaction.family is None:
rmgdb.determine_family(reaction)
amplitudes = amplitudes or [0.1, 0.2, 0.4, 0.6, 0.8, 1]
amplitudes = [amplitudes] if isinstance(amplitudes, float) else amplitudes
reaction.ts_species.ts_checks['NMD'] = False
rmg_reactions = get_rmg_reactions_from_arc_reaction(arc_reaction=reaction) or list()
freqs, normal_modes_disp = parser.parse_normal_mode_displacement(path=job.local_path_to_output_file, raise_error=False)
if not len(normal_modes_disp):
return
largest_neg_freq_idx = get_index_of_abs_largest_neg_freq(freqs)
bond_lone_hs = any(len(spc.mol.atoms) == 2 and spc.mol.atoms[0].element.symbol == 'H'
and spc.mol.atoms[0].element.symbol == 'H' for spc in reaction.r_species + reaction.p_species)
# bond_lone_hs = False
xyz = parser.parse_xyz_from_file(job.local_path_to_output_file)
if not xyz['coords']:
xyz = reaction.ts_species.get_xyz()

done = False
for amplitude in amplitudes:
xyz_1, xyz_2 = displace_xyz(xyz=xyz, displacement=normal_modes_disp[largest_neg_freq_idx], amplitude=amplitude)
dmat_1, dmat_2 = xyz_to_dmat(xyz_1), xyz_to_dmat(xyz_2)
dmat_bonds_1 = get_bonds_from_dmat(dmat=dmat_1,
elements=xyz_1['symbols'],
tolerance=1.5,
bond_lone_hydrogens=bond_lone_hs)
dmat_bonds_2 = get_bonds_from_dmat(dmat=dmat_2,
elements=xyz_2['symbols'],
tolerance=1.5,
bond_lone_hydrogens=bond_lone_hs)
got_expected_changing_bonds = False
for i, rmg_reaction in enumerate(rmg_reactions):
r_label_dict = get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(arc_reaction=reaction,
rmg_reaction=rmg_reaction)[0]
if r_label_dict is None:
continue
expected_breaking_bonds, expected_forming_bonds = reaction.get_expected_changing_bonds(r_label_dict=r_label_dict)
if expected_breaking_bonds is None or expected_forming_bonds is None:
continue
got_expected_changing_bonds = True
breaking = [determine_changing_bond(bond, dmat_bonds_1, dmat_bonds_2) for bond in expected_breaking_bonds]
forming = [determine_changing_bond(bond, dmat_bonds_1, dmat_bonds_2) for bond in expected_forming_bonds]
if len(breaking) and len(forming) \
and not any(entry is None for entry in breaking) and not any(entry is None for entry in forming) \
and all(entry == breaking[0] for entry in breaking) and all(entry == forming[0] for entry in forming) \
and breaking[0] != forming[0]:
reaction.ts_species.ts_checks['NMD'] = True
done = True
break
if not got_expected_changing_bonds and not reaction.ts_species.ts_checks['NMD']:
reaction.ts_species.ts_checks['warnings'] += 'Could not compare normal mode displacement to expected ' \
'breaking/forming bonds due to a missing RMG template; '
reaction.ts_species.ts_checks['NMD'] = True
break
if not len(rmg_reactions):
# Just check that some bonds break/form, and that this is not a torsional saddle point.
warning = f'Cannot check normal mode displacement for reaction {reaction} since a corresponding ' \
f'RMG template could not be generated'
logger.warning(warning)
reaction.ts_species.ts_checks['warnings'] += warning + '; '
if any(bond not in dmat_bonds_2 for bond in dmat_bonds_1) \
or any(bond not in dmat_bonds_1 for bond in dmat_bonds_2):
reaction.ts_species.ts_checks['NMD'] = True
break
if done:
break
amplitude (Union[float, list]): The amplitude of the normal mode displacement motion to check.
If a list, all possible results are returned.
"""
amplitude = amplitude or 0.25
reaction.ts_species.ts_checks['NMD'] = analyze_ts_normal_mode_displacement(reaction=reaction,
job=job,
amplitude=amplitude,
)


def determine_changing_bond(bond: Tuple[int, ...],
Expand Down
Loading

0 comments on commit f5b373d

Please sign in to comment.