forked from MarcusOlivecrona/REINVENT
-
Notifications
You must be signed in to change notification settings - Fork 1
/
sample_smiles.py
75 lines (54 loc) · 1.99 KB
/
sample_smiles.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import numpy as np
import torch
from model import RNN
from data_structs import Vocabulary, Experience
from utils import seq_to_smiles
from rdkit.Chem import MolFromSmiles
from rdkit import rdBase
# Module-level side effect: switch off RDKit's 'rdApp.error' log channel for
# the whole process. Presumably this keeps the many unparsable SMILES that
# check_unique_valid() feeds to MolFromSmiles from flooding the console --
# TODO confirm this is the only reason before re-enabling it elsewhere.
rdBase.DisableLog('rdApp.error')
def generate_smiles(n_smiles=500, restore_from="data/Prior.ckpt", voc_file="data/Voc", embedding_size=128,
                    batch_size=32):
    """
    Sample new SMILES strings from a trained RNN checkpoint.

    The vocabulary is built from ``voc_file``, an ``RNN`` is created with
    ``embedding_size`` and its weights are restored from ``restore_from``.
    Sequences are then sampled ``batch_size`` at a time and converted with
    ``seq_to_smiles``.

    Args:
        n_smiles: Requested number of SMILES. It is rounded DOWN to the
            nearest multiple of ``batch_size``, so a request smaller than
            ``batch_size`` produces an empty list.
        restore_from: Path of the checkpoint handed to ``torch.load``.
        voc_file: Path of the vocabulary file handed to ``Vocabulary``.
        embedding_size: Forwarded unchanged to the ``RNN`` constructor.
        batch_size: Number of sequences requested from ``generator.sample``
            per call. Defaults to 32, the value that used to be hard-coded.

    Returns:
        A list holding the concatenated ``seq_to_smiles`` output of every
        sampled batch. Nothing is deduplicated or validated here.

    Raises:
        ValueError: If ``batch_size`` is not a positive number.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer, got %r" % (batch_size,))

    # Only whole batches are sampled, so drop the remainder up front and
    # report the number that will really be produced.
    n_smiles = n_smiles - n_smiles % batch_size
    print("Generating %i smiles" % n_smiles)

    voc = Vocabulary(init_from_file=voc_file)
    generator = RNN(voc, embedding_size)

    # Without CUDA the checkpoint's tensors have to be remapped onto the CPU;
    # with CUDA the default placement (map_location=None) is kept.
    if torch.cuda.is_available():
        map_location = None
    else:
        map_location = lambda storage, loc: storage
    # NOTE: torch.load unpickles the file -- only restore trusted checkpoints.
    generator.rnn.load_state_dict(torch.load(restore_from, map_location=map_location))

    all_smiles = []
    # Pure inference: no autograd graph is needed while sampling.
    with torch.no_grad():
        for _ in range(n_smiles // batch_size):
            sequences, _, _ = generator.sample(batch_size)
            all_smiles += seq_to_smiles(sequences, voc)

    # Freeing up memory
    del generator
    torch.cuda.empty_cache()

    return all_smiles
def check_unique_valid(smiles):
    """
    Deduplicate SMILES strings and keep the ones RDKit can parse.

    Args:
        smiles: A sized collection of SMILES strings; duplicates are allowed.

    Returns:
        A tuple ``(valid_smiles, perc_unique, perc_valid)``:

        * ``valid_smiles`` -- list of the unique strings for which
          ``MolFromSmiles`` did not return ``None``. The order is arbitrary
          because the strings pass through a ``set``.
        * ``perc_unique`` -- unique strings as a percentage of all inputs.
        * ``perc_valid`` -- valid strings as a percentage of the unique ones.

        A percentage whose denominator is zero (empty input) is reported as
        0.0 instead of raising ``ZeroDivisionError``.
    """
    n_tot = len(smiles)
    unique_smiles = set(smiles)
    n_unique = len(unique_smiles)

    # MolFromSmiles signals an unparsable string by returning None.
    valid_smiles = [smile for smile in unique_smiles if MolFromSmiles(smile) is not None]
    n_valid = len(valid_smiles)

    perc_unique = n_unique / n_tot * 100 if n_tot else 0.0
    perc_valid = n_valid / n_unique * 100 if n_unique else 0.0
    return valid_smiles, perc_unique, perc_valid
def write_smiles(smiles, filename="smiles.smi"):
    """
    Write SMILES strings to a text file, one per line.

    Every line consists of the SMILES string, a tab character and the string's
    0-based position in ``smiles``, followed by a newline. An existing file at
    ``filename`` is overwritten.

    Args:
        smiles: Iterable of SMILES strings.
        filename: Path of the output file.
    """
    # The context manager closes the file even if a write fails part-way
    # (for example when an element is not a string).
    with open(filename, "w") as f_out:
        f_out.writelines(
            smile + "\t" + str(i) + "\n" for i, smile in enumerate(smiles)
        )