Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add DiGress workflow interface tutorial #1057

Draft
wants to merge 2 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
557 changes: 557 additions & 0 deletions openfl-tutorials/experimental/DiGress/Workflow_Interface_DiGress.ipynb

Large diffs are not rendered by default.

Empty file.
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,339 @@
# Copyright (c) 2012-2022 Clement Vignac, Igor Krawczuk, Antoine Siraudin
# source: https://github.com/cvignac/DiGress/

import numpy as np
import torch
import re
# import wandb
try:
from rdkit import Chem
print("Found rdkit, all good")
except ModuleNotFoundError as e:
use_rdkit = False
from warnings import warn
warn("Didn't find rdkit, this will fail")
assert use_rdkit, "Didn't find rdkit"

from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

allowed_bonds = {'H': 1, 'C': 4, 'N': 3, 'O': 2, 'F': 1, 'B': 3, 'Al': 3, 'Si': 4, 'P': [3, 5],
'S': 4, 'Cl': 1, 'As': 3, 'Br': 1, 'I': 1, 'Hg': [1, 2], 'Bi': [3, 5], 'Se': [2, 4, 6]}
bond_dict = [None, Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE,
Chem.rdchem.BondType.AROMATIC]
ATOM_VALENCY = {6: 4, 7: 3, 8: 2, 9: 1, 15: 3, 16: 2, 17: 1, 35: 1, 53: 1}


class BasicMolecularMetrics(object):
def __init__(self, dataset_info, train_smiles=None):
self.atom_decoder = dataset_info.atom_decoder
self.dataset_info = dataset_info

# Retrieve dataset smiles only for qm9 currently.
self.dataset_smiles_list = train_smiles

def compute_validity(self, generated):
""" generated: list of couples (positions, atom_types)"""
valid = []
num_components = []
all_smiles = []
for graph in generated:
atom_types, edge_types = graph
mol = build_molecule(atom_types, edge_types, self.dataset_info.atom_decoder)
smiles = mol2smiles(mol)
try:
mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
num_components.append(len(mol_frags))
except:
pass
if smiles is not None:
try:
mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
largest_mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
smiles = mol2smiles(largest_mol)
valid.append(smiles)
all_smiles.append(smiles)
except Chem.rdchem.AtomValenceException:
print("Valence error in GetmolFrags")
all_smiles.append(None)
except Chem.rdchem.KekulizeException:
print("Can't kekulize molecule")
all_smiles.append(None)
else:
all_smiles.append(None)

return valid, len(valid) / len(generated), np.array(num_components), all_smiles

def compute_uniqueness(self, valid):
""" valid: list of SMILES strings."""
return list(set(valid)), len(set(valid)) / len(valid)

def compute_novelty(self, unique):
num_novel = 0
novel = []
if self.dataset_smiles_list is None:
print("Dataset smiles is None, novelty computation skipped")
return 1, 1
for smiles in unique:
if smiles not in self.dataset_smiles_list:
novel.append(smiles)
num_novel += 1
return novel, num_novel / len(unique)

def compute_relaxed_validity(self, generated):
valid = []
for graph in generated:
atom_types, edge_types = graph
mol = build_molecule_with_partial_charges(atom_types, edge_types, self.dataset_info.atom_decoder)
smiles = mol2smiles(mol)
if smiles is not None:
try:
mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=True)
largest_mol = max(mol_frags, default=mol, key=lambda m: m.GetNumAtoms())
smiles = mol2smiles(largest_mol)
valid.append(smiles)
except Chem.rdchem.AtomValenceException:
print("Valence error in GetmolFrags")
except Chem.rdchem.KekulizeException:
print("Can't kekulize molecule")
return valid, len(valid) / len(generated)

def evaluate(self, generated):
""" generated: list of pairs (positions: n x 3, atom_types: n [int])
the positions and atom types should already be masked. """
valid, validity, num_components, all_smiles = self.compute_validity(generated)
nc_mu = num_components.mean() if len(num_components) > 0 else 0
nc_min = num_components.min() if len(num_components) > 0 else 0
nc_max = num_components.max() if len(num_components) > 0 else 0
# print(f"Validity over {len(generated)} molecules: {validity * 100 :.2f}%")
# print(f"Number of connected components of {len(generated)} molecules: min:{nc_min:.2f} mean:{nc_mu:.2f} max:{nc_max:.2f}")

relaxed_valid, relaxed_validity = self.compute_relaxed_validity(generated)
# print(f"Relaxed validity over {len(generated)} molecules: {relaxed_validity * 100 :.2f}%")
if relaxed_validity > 0:
unique, uniqueness = self.compute_uniqueness(relaxed_valid)
# print(f"Uniqueness over {len(relaxed_valid)} valid molecules: {uniqueness * 100 :.2f}%")

if self.dataset_smiles_list is not None:
_, novelty = self.compute_novelty(unique)
# print(f"Novelty over {len(unique)} unique valid molecules: {novelty * 100 :.2f}%")
else:
novelty = -1.0
else:
novelty = -1.0
uniqueness = 0.0
unique = []
return ([validity, relaxed_validity, uniqueness, novelty], unique,
dict(nc_min=nc_min, nc_max=nc_max, nc_mu=nc_mu), all_smiles)


def mol2smiles(mol):
try:
Chem.SanitizeMol(mol)
except ValueError:
return None
return Chem.MolToSmiles(mol)


def build_molecule(atom_types, edge_types, atom_decoder, verbose=False):
if verbose:
print("building new molecule")

mol = Chem.RWMol()
for atom in atom_types:
a = Chem.Atom(atom_decoder[atom.item()])
mol.AddAtom(a)
if verbose:
print("Atom added: ", atom.item(), atom_decoder[atom.item()])

edge_types = torch.triu(edge_types)
all_bonds = torch.nonzero(edge_types)
for i, bond in enumerate(all_bonds):
if bond[0].item() != bond[1].item():
mol.AddBond(bond[0].item(), bond[1].item(), bond_dict[edge_types[bond[0], bond[1]].item()])
if verbose:
print("bond added:", bond[0].item(), bond[1].item(), edge_types[bond[0], bond[1]].item(),
bond_dict[edge_types[bond[0], bond[1]].item()] )
return mol


def build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder, verbose=False):
if verbose:
print("\nbuilding new molecule")

mol = Chem.RWMol()
for atom in atom_types:
a = Chem.Atom(atom_decoder[atom.item()])
mol.AddAtom(a)
if verbose:
print("Atom added: ", atom.item(), atom_decoder[atom.item()])
edge_types = torch.triu(edge_types)
all_bonds = torch.nonzero(edge_types)

for i, bond in enumerate(all_bonds):
if bond[0].item() != bond[1].item():
mol.AddBond(bond[0].item(), bond[1].item(), bond_dict[edge_types[bond[0], bond[1]].item()])
if verbose:
print("bond added:", bond[0].item(), bond[1].item(), edge_types[bond[0], bond[1]].item(),
bond_dict[edge_types[bond[0], bond[1]].item()])
# add formal charge to atom: e.g. [O+], [N+], [S+]
# not support [O-], [N-], [S-], [NH+] etc.
flag, atomid_valence = check_valency(mol)
if verbose:
print("flag, valence", flag, atomid_valence)
if flag:
continue
else:
assert len(atomid_valence) == 2
idx = atomid_valence[0]
v = atomid_valence[1]
an = mol.GetAtomWithIdx(idx).GetAtomicNum()
if verbose:
print("atomic num of atom with a large valence", an)
if an in (7, 8, 16) and (v - ATOM_VALENCY[an]) == 1:
mol.GetAtomWithIdx(idx).SetFormalCharge(1)
# print("Formal charge added")
return mol


# Functions from GDSS
def check_valency(mol):
try:
Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
return True, None
except ValueError as e:
e = str(e)
p = e.find('#')
e_sub = e[p:]
atomid_valence = list(map(int, re.findall(r'\d+', e_sub)))
return False, atomid_valence


def correct_mol(m):
# xsm = Chem.MolToSmiles(x, isomericSmiles=True)
mol = m

#####
no_correct = False
flag, _ = check_valency(mol)
if flag:
no_correct = True

while True:
flag, atomid_valence = check_valency(mol)
if flag:
break
else:
assert len(atomid_valence) == 2
idx = atomid_valence[0]
v = atomid_valence[1]
queue = []
check_idx = 0
for b in mol.GetAtomWithIdx(idx).GetBonds():
type = int(b.GetBondType())
queue.append((b.GetIdx(), type, b.GetBeginAtomIdx(), b.GetEndAtomIdx()))
if type == 12:
check_idx += 1
queue.sort(key=lambda tup: tup[1], reverse=True)

if queue[-1][1] == 12:
return None, no_correct
elif len(queue) > 0:
start = queue[check_idx][2]
end = queue[check_idx][3]
t = queue[check_idx][1] - 1
mol.RemoveBond(start, end)
if t >= 1:
mol.AddBond(start, end, bond_dict[t])
return mol, no_correct


def valid_mol_can_with_seg(m, largest_connected_comp=True):
if m is None:
return None
sm = Chem.MolToSmiles(m, isomericSmiles=True)
if largest_connected_comp and '.' in sm:
vsm = [(s, len(s)) for s in sm.split('.')] # 'C.CC.CCc1ccc(N)cc1CCC=O'.split('.')
vsm.sort(key=lambda tup: tup[1], reverse=True)
mol = Chem.MolFromSmiles(vsm[0][0])
else:
mol = Chem.MolFromSmiles(sm)
return mol


if __name__ == '__main__':
smiles_mol = 'C1CCC1'
print("Smiles mol %s" % smiles_mol)
chem_mol = Chem.MolFromSmiles(smiles_mol)
block_mol = Chem.MolToMolBlock(chem_mol)
print("Block mol:")
print(block_mol)

use_rdkit = True


def check_stability(atom_types, edge_types, dataset_info, debug=False,atom_decoder=None):
if atom_decoder is None:
atom_decoder = dataset_info.atom_decoder

n_bonds = np.zeros(len(atom_types), dtype='int')

for i in range(len(atom_types)):
for j in range(i + 1, len(atom_types)):
n_bonds[i] += abs((edge_types[i, j] + edge_types[j, i])/2)
n_bonds[j] += abs((edge_types[i, j] + edge_types[j, i])/2)
n_stable_bonds = 0
for atom_type, atom_n_bond in zip(atom_types, n_bonds):
possible_bonds = allowed_bonds[atom_decoder[atom_type]]
if type(possible_bonds) == int:
is_stable = possible_bonds == atom_n_bond
else:
is_stable = atom_n_bond in possible_bonds
if not is_stable and debug:
print("Invalid bonds for molecule %s with %d bonds" % (atom_decoder[atom_type], atom_n_bond))
n_stable_bonds += int(is_stable)

molecule_stable = n_stable_bonds == len(atom_types)
return molecule_stable, n_stable_bonds, len(atom_types)


def compute_molecular_metrics(molecule_list, train_smiles, dataset_info):
""" molecule_list: (dict) """

if not dataset_info.remove_h:
print(f'Analyzing molecule stability...')

molecule_stable = 0
nr_stable_bonds = 0
n_atoms = 0
n_molecules = len(molecule_list)

for i, mol in enumerate(molecule_list):
atom_types, edge_types = mol

validity_results = check_stability(atom_types, edge_types, dataset_info)

molecule_stable += int(validity_results[0])
nr_stable_bonds += int(validity_results[1])
n_atoms += int(validity_results[2])

# Validity
fraction_mol_stable = molecule_stable / float(n_molecules)
fraction_atm_stable = nr_stable_bonds / float(n_atoms)
validity_dict = {'mol_stable': fraction_mol_stable, 'atm_stable': fraction_atm_stable}
# if wandb.run:
# wandb.log(validity_dict)
else:
validity_dict = {'mol_stable': -1, 'atm_stable': -1}

metrics = BasicMolecularMetrics(dataset_info, train_smiles)
rdkit_metrics = metrics.evaluate(molecule_list)
all_smiles = rdkit_metrics[-1]
# if wandb.run:
# nc = rdkit_metrics[-2]
# dic = {'Validity': rdkit_metrics[0][0], 'Relaxed Validity': rdkit_metrics[0][1],
# 'Uniqueness': rdkit_metrics[0][2], 'Novelty': rdkit_metrics[0][3],
# 'nc_max': nc['nc_max'], 'nc_mu': nc['nc_mu']}
# wandb.log(dic)

return validity_dict, rdkit_metrics, all_smiles
Loading
Loading