# data_structs.py -- forked from MarcusOlivecrona/REINVENT
import pickle
import random
import re
import sys
import time

import numpy as np
import torch
from rdkit import Chem
from torch.utils.data import Dataset

from utils import Variable

class Vocabulary(object):
    """A class for handling encoding/decoding from SMILES to an array of indices."""

    def __init__(self, init_from_file=None, max_length=140):
        self.special_tokens = ['EOS', 'GO']
        self.additional_chars = set()
        self.chars = list(self.special_tokens)
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        self.reversed_vocab = {v: k for k, v in self.vocab.items()}
        self.max_length = max_length
        if init_from_file:
            self.init_from_file(init_from_file)

    def encode(self, char_list):
        """Takes a list of tokens (eg '[NH]') and encodes them to an array of indices."""
        smiles_matrix = np.zeros(len(char_list), dtype=np.float32)
        for i, char in enumerate(char_list):
            smiles_matrix[i] = self.vocab[char]
        return smiles_matrix

    def decode(self, matrix):
        """Takes an array of indices and returns the corresponding SMILES."""
        chars = []
        for i in matrix:
            if i == self.vocab['EOS']:
                break
            chars.append(self.reversed_vocab[i])
        smiles = "".join(chars)
        # Restore the halogens that replace_halogen() collapsed to single letters.
        smiles = smiles.replace("L", "Cl").replace("R", "Br")
        return smiles

    def tokenize(self, smiles):
        """Takes a SMILES string and returns a list of tokens."""
        regex = r'(\[[^\[\]]{1,6}\])'
        smiles = replace_halogen(smiles)
        char_list = re.split(regex, smiles)
        tokenized = []
        for char in char_list:
            if char.startswith('['):
                # Bracket atoms such as '[nH]' are kept as single tokens.
                tokenized.append(char)
            else:
                tokenized.extend(list(char))
        tokenized.append('EOS')
        return tokenized

    def add_characters(self, chars):
        """Adds characters to the vocabulary."""
        for char in chars:
            self.additional_chars.add(char)
        char_list = sorted(self.additional_chars)
        self.chars = char_list + self.special_tokens
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        self.reversed_vocab = {v: k for k, v in self.vocab.items()}

    def init_from_file(self, file):
        """Takes a file containing newline-separated characters to initialize the vocabulary."""
        with open(file, 'r') as f:
            chars = f.read().split()
        self.add_characters(chars)

    def __len__(self):
        return len(self.chars)

    def __str__(self):
        return "Vocabulary containing {} tokens: {}".format(len(self), self.chars)
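
# Example usage of Vocabulary (a minimal sketch, not part of the original module).
# It assumes a vocabulary file such as 'data/Voc' built by construct_vocabulary()
# below; the exact token set depends on the training data:
#
#   voc = Vocabulary(init_from_file='data/Voc')
#   tokens = voc.tokenize('c1ccccc1Cl')   # bracket atoms stay whole, 'Cl' becomes 'L'
#   encoded = voc.encode(tokens)          # float32 array of vocabulary indices
#   smiles = voc.decode(encoded)          # 'c1ccccc1Cl', halogens restored on decode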

class MolData(Dataset):
    """Custom PyTorch Dataset that takes a file containing SMILES.

    Args:
        fname : path to a file containing newline-separated SMILES.
        voc   : a Vocabulary instance

    Returns:
        A custom PyTorch dataset for training the Prior.
    """

    def __init__(self, fname, voc):
        self.voc = voc
        self.smiles = []
        with open(fname, 'r') as f:
            for line in f:
                self.smiles.append(line.split()[0])

    def __getitem__(self, i):
        mol = self.smiles[i]
        tokenized = self.voc.tokenize(mol)
        encoded = self.voc.encode(tokenized)
        return Variable(encoded)

    def __len__(self):
        return len(self.smiles)

    def __str__(self):
        return "Dataset containing {} structures.".format(len(self))

    @classmethod
    def collate_fn(cls, arr):
        """Takes a list of encoded sequences and turns them into a zero-padded batch."""
        max_length = max([seq.size(0) for seq in arr])
        collated_arr = Variable(torch.zeros(len(arr), max_length))
        for i, seq in enumerate(arr):
            collated_arr[i, :seq.size(0)] = seq
        return collated_arr
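
# Example usage of MolData with a DataLoader (a minimal sketch, not part of the
# original module). It assumes `voc` is a Vocabulary instance as above;
# 'data/mols_filtered.smi' is the file written by the __main__ block below and
# the batch size is arbitrary:
#
#   from torch.utils.data import DataLoader
#   moldata = MolData('data/mols_filtered.smi', voc)
#   loader = DataLoader(moldata, batch_size=128, shuffle=True,
#                       collate_fn=MolData.collate_fn)
#   for batch in loader:
#       # batch is a (batch_size, max_seq_len) tensor of token indices
#       ...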

class Experience(object):
    """Prioritized experience replay that remembers the highest scored sequences
    seen so far and samples from them with probabilities proportional to their scores."""

    def __init__(self, voc, max_size=100):
        self.memory = []
        self.max_size = max_size
        self.voc = voc

    def add_experience(self, experience):
        """Experience should be a list of (smiles, score, prior likelihood) tuples."""
        self.memory.extend(experience)
        if len(self.memory) > self.max_size:
            # Remove duplicates, keeping the first occurrence of each SMILES
            idxs, smiles = [], []
            for i, exp in enumerate(self.memory):
                if exp[0] not in smiles:
                    idxs.append(i)
                    smiles.append(exp[0])
            self.memory = [self.memory[idx] for idx in idxs]
            # Retain the highest scoring sequences
            self.memory.sort(key=lambda x: x[1], reverse=True)
            self.memory = self.memory[:self.max_size]
            # Needed because sometimes the printing makes the calculation fail
            try:
                print("\nBest score in memory: {:.2f}".format(self.memory[0][1]))
            except TypeError:
                print("\nBest score in memory: %s" % str(self.memory[0][1]))

    def sample(self, n):
        """Samples a batch of size n from the memory, weighted by score."""
        if len(self.memory) < n:
            raise IndexError('Size of memory ({}) is less than requested sample ({})'.format(len(self), n))
        else:
            scores = [x[1] for x in self.memory]
            sample = np.random.choice(len(self), size=n, replace=False, p=scores / np.sum(scores))
            sample = [self.memory[i] for i in sample]
            smiles = [x[0] for x in sample]
            scores = [x[1] for x in sample]
            prior_likelihood = [x[2] for x in sample]
        tokenized = [self.voc.tokenize(smile) for smile in smiles]
        encoded = [Variable(self.voc.encode(tokenized_i)) for tokenized_i in tokenized]
        encoded = MolData.collate_fn(encoded)
        return encoded, np.array(scores), np.array(prior_likelihood)

    def initiate_from_file(self, fname, scoring_function, Prior):
        """Adds experience from a file of SMILES.

        Needs a scoring function and an RNN (Prior) to score the sequences.
        Using this feature means that the learning can be very biased
        and is typically advised against."""
        with open(fname, 'r') as f:
            smiles = []
            for line in f:
                smile = line.split()[0]
                if Chem.MolFromSmiles(smile):
                    smiles.append(smile)
        scores = scoring_function(smiles)
        tokenized = [self.voc.tokenize(smile) for smile in smiles]
        encoded = [Variable(self.voc.encode(tokenized_i)) for tokenized_i in tokenized]
        encoded = MolData.collate_fn(encoded)
        prior_likelihood, _ = Prior.likelihood(encoded.long())
        prior_likelihood = prior_likelihood.data.cpu().numpy()
        new_experience = zip(smiles, scores, prior_likelihood)
        self.add_experience(new_experience)

    def print_memory(self, path):
        """Prints the best scored SMILES in memory and writes them to a file."""
        print("\n" + "*" * 80 + "\n")
        print("         Best recorded SMILES: \n")
        print("Score     Prior log P     SMILES\n")
        with open(path, 'w') as f:
            f.write("SMILES Score PriorLogP\n")
            for i, exp in enumerate(self.memory[:100]):
                if i < 50:
                    print("{:4.2f}   {:6.2f}        {}".format(exp[1], exp[2], exp[0]))
                f.write("{} {:4.2f} {:6.2f}\n".format(*exp))
        print("\n" + "*" * 80 + "\n")

    def __len__(self):
        return len(self.memory)
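
# Example usage of Experience during an RL fine-tuning step (a minimal sketch, not
# part of the original module; `voc`, `smiles`, `scores` and `prior_likelihood` are
# assumed to come from the surrounding training loop):
#
#   experience = Experience(voc, max_size=100)
#   experience.add_experience(zip(smiles, scores, prior_likelihood))
#   if len(experience) > 4:
#       exp_seqs, exp_scores, exp_prior = experience.sample(4)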

def replace_halogen(string):
    """Replaces the two-character halogens Br and Cl with the single letters R and L."""
    br = re.compile('Br')
    cl = re.compile('Cl')
    string = br.sub('R', string)
    string = cl.sub('L', string)
    return string


def tokenize(smiles):
    """Takes a SMILES string and returns a list of tokens.

    This will swap 'Cl' and 'Br' to 'L' and 'R' and treat
    '[xx]' as one token."""
    regex = r'(\[[^\[\]]{1,6}\])'
    smiles = replace_halogen(smiles)
    char_list = re.split(regex, smiles)
    tokenized = []
    for char in char_list:
        if char.startswith('['):
            tokenized.append(char)
        else:
            tokenized.extend(list(char))
    tokenized.append('EOS')
    return tokenized
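
# Example (a minimal sketch, not part of the original module): tokenizing a SMILES
# that contains a two-letter halogen and a bracket atom.
#
#   tokenize('c1cc(Cl)c[nH]1')
#   # -> ['c', '1', 'c', 'c', '(', 'L', ')', 'c', '[nH]', '1', 'EOS']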

def canonicalize_smiles_from_file(fname):
    """Reads a SMILES file and returns a list of RDKit-canonicalized SMILES."""
    with open(fname, 'r') as f:
        smiles_list = []
        for i, line in enumerate(f):
            if i % 100000 == 0:
                print("{} lines processed.".format(i))
            smiles = line.split(" ")[0]
            mol = Chem.MolFromSmiles(smiles)
            if filter_mol(mol):
                smiles_list.append(Chem.MolToSmiles(mol))
        print("{} SMILES retrieved".format(len(smiles_list)))
        return smiles_list


def filter_mol(mol, max_heavy_atoms=50, min_heavy_atoms=10, element_list=[6, 7, 8, 9, 16, 17, 35, 53]):
    """Filters molecules on number of heavy atoms and atom types."""
    if mol is None:
        return False
    num_heavy = min_heavy_atoms < mol.GetNumHeavyAtoms() < max_heavy_atoms
    elements = all([atom.GetAtomicNum() in element_list for atom in mol.GetAtoms()])
    return num_heavy and elements
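
# Example (a minimal sketch, not part of the original module): with the default
# thresholds, a 12-heavy-atom hydrocarbon passes while a 5-heavy-atom one does not.
#
#   filter_mol(Chem.MolFromSmiles('CCCCCCCCCCCC'))   # True, 12 heavy atoms, all carbon
#   filter_mol(Chem.MolFromSmiles('CCCCC'))          # False, too few heavy atoms
#   filter_mol(None)                                 # False, unparsable SMILES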

def write_smiles_to_file(smiles_list, fname):
    """Write a list of SMILES to a file."""
    with open(fname, 'w') as f:
        for smiles in smiles_list:
            f.write(smiles + "\n")


def filter_on_chars(smiles_list, chars):
    """Filters SMILES on the characters they contain.

    Used to remove SMILES containing very rare/undesirable characters."""
    smiles_list_valid = []
    for smiles in smiles_list:
        tokenized = tokenize(smiles)
        # The trailing 'EOS' token is not part of the vocabulary file, so ignore it.
        if all([char in chars for char in tokenized[:-1]]):
            smiles_list_valid.append(smiles)
    return smiles_list_valid

def filter_file_on_chars(smiles_fname, voc_fname):
    """Filters a SMILES file using a vocabulary file.

    Only SMILES containing nothing but the characters
    in the vocabulary will be retained."""
    smiles = []
    with open(smiles_fname, 'r') as f:
        for line in f:
            smiles.append(line.split()[0])
    print(smiles[:10])
    chars = []
    with open(voc_fname, 'r') as f:
        for line in f:
            chars.append(line.split()[0])
    print(chars)
    valid_smiles = filter_on_chars(smiles, chars)
    with open(smiles_fname + "_filtered", 'w') as f:
        for smiles in valid_smiles:
            f.write(smiles + "\n")


def combine_voc_from_files(fnames, out_name):
    """Combines multiple vocabulary files into one."""
    chars = set()
    for fname in fnames:
        with open(fname, 'r') as f:
            for line in f:
                chars.add(line.split()[0])
    with open(out_name, 'w') as f:
        for char in chars:
            f.write(char + "\n")

def construct_vocabulary(smiles_list, file_name='data/Voc'):
    """Returns all tokens present in a list of SMILES and writes them to a vocabulary file.

    Uses a regex to treat tokens of the form '[x]' as single characters."""
    add_chars = set()
    regex = r'(\[[^\[\]]{1,6}\])'
    for smiles in smiles_list:
        smiles = replace_halogen(smiles)
        char_list = re.split(regex, smiles)
        for char in char_list:
            if char.startswith('['):
                add_chars.add(char)
            else:
                add_chars.update(char)
    print("Number of characters: {}".format(len(add_chars)))
    with open(file_name, 'w') as f:
        for char in add_chars:
            f.write(char + "\n")
    return add_chars
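
# Example (a minimal sketch, not part of the original module): building the
# vocabulary file from canonicalized SMILES and loading it back into a Vocabulary.
# 'data/chembl.smi' is a hypothetical input file; 'data/Voc' mirrors the default
# used by construct_vocabulary():
#
#   smiles_list = canonicalize_smiles_from_file('data/chembl.smi')
#   construct_vocabulary(smiles_list, file_name='data/Voc')
#   voc = Vocabulary(init_from_file='data/Voc')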

def get_dataset_name(fname):
    """Gets the name of a data set"""
    fname = fname.split("/")[-1]
    fname = fname.split(".")[0]
    return fname


if __name__ == "__main__":
    smiles_file = sys.argv[1]
    print("Reading smiles...")
    smiles_list = canonicalize_smiles_from_file(smiles_file)
    print("Constructing vocabulary...")
    voc_chars = construct_vocabulary(smiles_list)
    write_smiles_to_file(smiles_list, "data/mols_filtered.smi")