Skip to content

Commit

Permalink
Refactor hydrogen atom selection logic in get_h_abs_atoms function fo…
Browse files Browse the repository at this point in the history
…r improved clarity and efficiency
  • Loading branch information
calvinp0 committed Dec 28, 2024
1 parent d485401 commit 2b5259e
Showing 1 changed file with 76 additions and 68 deletions.
144 changes: 76 additions & 68 deletions arc/job/adapters/ts/heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from rmgpy.species import Species

from arc.common import almost_equal_coords, get_logger, is_angle_linear, key_by_val
from arc.imports import settings
from arc.imports import settings, submit_scripts
from arc.job.adapter import JobAdapter
from arc.job.adapters.common import _initialize_adapter, ts_adapters_by_rmg_family
from arc.job.factory import register_job_adapter
Expand Down Expand Up @@ -1141,7 +1141,7 @@ def h_abstraction(
# Don't modify dihedrals for an attacking H (or other linear radical) at a linear angle, C ~ A -- H1 - H2 -- H.
dihedral_increment = 360

for rmg_reaction in rmg_reactions:
for i, rmg_reaction in enumerate(rmg_reactions):
rmg_reactant_mol = rmg_reaction.reactants[int(reactants_reversed)].molecule[0]
rmg_product_mol = rmg_reaction.products[int(not products_reversed)].molecule[0]
h1 = rmg_reactant_mol.atoms.index(
Expand Down Expand Up @@ -1179,7 +1179,7 @@ def h_abstraction(
d2_d3_product = [(None, d3) for d3 in d3_values]
else:
d2_d3_product = [(None, None)]

rmg_xyz_guesses = list()
zmats = list()
for d2, d3 in d2_d3_product:
xyz_guess = None
Expand Down Expand Up @@ -1217,10 +1217,10 @@ def h_abstraction(
# This TS is unique, and has no atom collisions.
zmats.append(zmat_guess)
xyz_guesses.append(xyz_guess)

if xyz_guesses:
for i, xyz_guess_crest in enumerate(xyz_guesses):
# 1. Check if dict
rmg_xyz_guesses.append(xyz_guess)
crest_paths = list()
if rmg_xyz_guesses and HAS_CREST:
xyz_guess_crest = rmg_xyz_guesses[0]
if isinstance(xyz_guess_crest, dict):
df_dmat = convert_xyz_to_df(xyz_guess_crest)
elif isinstance(xyz_guess_crest, str):
Expand Down Expand Up @@ -1328,13 +1328,8 @@ def crest_ts_conformer_search(
]
command = " ".join(commands)

if process.returncode == 0:
print("Command completed successfully.")
with open(os.path.join(path, 'crest_best.xyz'), 'r') as f:
content = f.read()

xyz_guess = str_to_xyz(content)
return xyz_guess
if CREST_ENV_PATH:
activation_line = CREST_ENV_PATH
else:
activation_line = ""

Expand Down Expand Up @@ -1510,72 +1505,85 @@ def get_h_abs_atoms(dataframe: pd.DataFrame) -> dict:
Returns:
dict: The hydrogen atom and the two heavy atoms. The keys are 'H', 'A', 'B'
"""
# Ensure there are at least 3 atoms in the TS
if len(dataframe) < 3:
raise ValueError("TS must contain at least 3 atoms.")
if len(dataframe) == 3 and dataframe.index.str.startswith("H").sum() == 2:
# Identify the heavy atom
heavy_atom = dataframe.index[~dataframe.index.str.startswith("H")][0] # Should be the only heavy atom
hydrogen_atoms = dataframe.index[dataframe.index.str.startswith("H")] # List of hydrogen atoms

# Get distances from the heavy atom to the two hydrogens
distances_to_hydrogens = dataframe.loc[heavy_atom, hydrogen_atoms]

# Select the hydrogen with the smallest distance to the heavy atom as `H`
hydrogen_with_min_distance = distances_to_hydrogens.idxmin()

# The other hydrogen becomes `B`
other_hydrogen = hydrogen_atoms[hydrogen_atoms != hydrogen_with_min_distance][0]
closest_atoms = {}
for index, row in dataframe.iterrows():

return {"H": hydrogen_with_min_distance, "A": heavy_atom, "B": other_hydrogen}
row[index] = np.inf
closest = row.nsmallest(2).index.tolist()
closest_atoms[index] = closest

elif len(dataframe) == 4 and dataframe.index.str.startswith("H").sum() == 3:
# Identify the heavy atom
heavy_atom = dataframe.index[~dataframe.index.str.startswith("H")][0] # Should be the only heavy atom
hydrogen_atoms = dataframe.index[dataframe.index.str.startswith("H")] # List of hydrogen atoms
hydrogen_keys = [key for key in dataframe.index if key.startswith("H")]
condition_occurrences = []

# Remove hydrogens from columns and the heavy atom from rows
filtered_df = dataframe.loc[hydrogen_atoms, [heavy_atom]]

# Sort the distances from the heavy atom to all hydrogens
sorted_distances = filtered_df[heavy_atom].sort_values()

# Select the hydrogen with the second furthest distance
hydrogen_with_max_distance = sorted_distances.index[-2]

# Reset the DataFrame back to the original to find the other hydrogen (`B`)
remaining_hydrogens = hydrogen_atoms[hydrogen_atoms != hydrogen_with_max_distance]
filtered_hydrogens_df = dataframe.loc[[hydrogen_with_max_distance], remaining_hydrogens]

# Find the hydrogen closest to the selected hydrogen (`H`)
closest_hydrogen = filtered_hydrogens_df.idxmin(axis=1).iloc[0]
for hydrogen_key in hydrogen_keys:
atom_neighbours = closest_atoms[hydrogen_key]
is_heavy_present = any(
atom for atom in closest_atoms if not atom.startswith("H")
)
if_hydrogen_present = any(
atom
for atom in closest_atoms
if atom.startswith("H") and atom != hydrogen_key
)

return {"H": hydrogen_with_max_distance, "A": heavy_atom, "B": closest_hydrogen}
if is_heavy_present and if_hydrogen_present:
# Store the details of this occurrence
condition_occurrences.append(
{"H": hydrogen_key, "A": atom_neighbours[0], "B": atom_neighbours[1]}
)

# Check if the condition was met
if condition_occurrences:
if len(condition_occurrences) > 1:
# Store distances to decide which occurrence to use
occurrence_distances = []
for occurrence in condition_occurrences:
# Calculate the sum of distances to the two heavy atoms
hydrogen_key = f"{occurrence['H']}"
heavy_atoms = [f"{occurrence['A']}", f"{occurrence['B']}"]
try:
distances = dataframe.loc[hydrogen_key, heavy_atoms].sum()
occurrence_distances.append((occurrence, distances))
except KeyError as e:
print(f"Error accessing distances for occurrence {occurrence}: {e}")

# Select the occurrence with the smallest distance
best_occurrence = min(occurrence_distances, key=lambda x: x[1])[0]
return {
"H": extract_digits(best_occurrence["H"]),
"A": extract_digits(best_occurrence["A"]),
"B": extract_digits(best_occurrence["B"]),
}
else:

# Filter the DataFrame for hydrogen rows and non-hydrogen columns
hydrogen_rows = dataframe.index[dataframe.index.str.startswith("H")]
heavy_atom_columns = dataframe.columns[~dataframe.columns.str.startswith("H")]

filtered_df = dataframe.loc[hydrogen_rows, heavy_atom_columns]
# Check the all the hydrogen atoms, and see the closest two heavy atoms and aggregate their distances to determine which Hyodrogen atom has the lowest distance aggregate
min_distance = np.inf
selected_hydrogen = None
selected_heavy_atoms = None

# Find the hydrogen atom with the smallest bond distance to a heavy atom
min_distances = filtered_df.min(axis=1)
min_distances = min_distances[min_distances <= 2.0]
hydrogen_with_min_distance = min_distances.idxmax()
min_distance_column = filtered_df.loc[hydrogen_with_min_distance].idxmin()
for hydrogen_key in hydrogen_keys:
atom_neighbours = closest_atoms[hydrogen_key]
heavy_atoms = [atom for atom in atom_neighbours if not atom.startswith("H")]

# Handle cases with multiple heavy atoms
remaining_columns = dataframe.columns[
~dataframe.columns.isin([hydrogen_with_min_distance, min_distance_column])
]
remaining_df = dataframe.loc[[hydrogen_with_min_distance], remaining_columns]
second_closest_atom = remaining_df.idxmin(axis=1).iloc[0]
if len(heavy_atoms) < 2:
continue

return {"H": hydrogen_with_min_distance, "A": min_distance_column, "B": second_closest_atom}
distances = dataframe.loc[hydrogen_key, heavy_atoms[:2]].sum()
if distances < min_distance:
min_distance = distances
selected_hydrogen = hydrogen_key
selected_heavy_atoms = heavy_atoms

if selected_hydrogen:
return {
"H": extract_digits(selected_hydrogen),
"A": extract_digits(selected_heavy_atoms[0]),
"B": extract_digits(selected_heavy_atoms[1]),
}
else:
raise ValueError("No valid hydrogen atom found.")


register_job_adapter("heuristics", HeuristicsAdapter)

0 comments on commit 2b5259e

Please sign in to comment.