-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathdataset.py
67 lines (57 loc) · 2.73 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
from torch.utils.data import Dataset, DataLoader
import random
import numpy as np
CHARISOSMISET = {"#": 29, "%": 30, ")": 31, "(": 1, "+": 32, "-": 33, "/": 34, ".": 2,
"1": 35, "0": 3, "3": 36, "2": 4, "5": 37, "4": 5, "7": 38, "6": 6,
"9": 39, "8": 7, "=": 40, "A": 41, "@": 8, "C": 42, "B": 9, "E": 43,
"D": 10, "G": 44, "F": 11, "I": 45, "H": 12, "K": 46, "M": 47, "L": 13,
"O": 48, "N": 14, "P": 15, "S": 49, "R": 16, "U": 50, "T": 17, "W": 51,
"V": 18, "Y": 52, "[": 53, "Z": 19, "]": 54, "\\": 20, "a": 55, "c": 56,
"b": 21, "e": 57, "d": 22, "g": 58, "f": 23, "i": 59, "h": 24, "m": 60,
"l": 25, "o": 61, "n": 26, "s": 62, "r": 27, "u": 63, "t": 28, "y": 64}
CHARISOSMILEN = 64
CHARPROTSET = {"A": 1, "C": 2, "B": 3, "E": 4, "D": 5, "G": 6,
"F": 7, "I": 8, "H": 9, "K": 10, "M": 11, "L": 12,
"O": 13, "N": 14, "Q": 15, "P": 16, "S": 17, "R": 18,
"U": 19, "T": 20, "W": 21, "V": 22, "Y": 23, "X": 24, "Z": 25}
CHARPROTLEN = 25
def label_smiles(line, smi_ch_ind, MAX_SMI_LEN=100):
X = np.zeros(MAX_SMI_LEN,dtype=np.int64())
for i, ch in enumerate(line[:MAX_SMI_LEN]):
X[i] = smi_ch_ind[ch]
return X
def label_sequence(line, smi_ch_ind, MAX_SEQ_LEN=1000):
X = np.zeros(MAX_SEQ_LEN,np.int64())
for i, ch in enumerate(line[:MAX_SEQ_LEN]):
X[i] = smi_ch_ind[ch]
return X
class CustomDataSet(Dataset):
def __init__(self, pairs):
self.pairs = pairs
def __getitem__(self, item):
return self.pairs[item]
def __len__(self):
return len(self.pairs)
def collate_fn(batch_data):
N = len(batch_data)
drug_ids, protein_ids = [],[]
compound_max = 100
protein_max = 1000
compound_new = torch.zeros((N, compound_max),dtype=torch.long)
protein_new = torch.zeros((N, protein_max),dtype=torch.long)
labels_new = torch.zeros(N, dtype=torch.long)
for i,pair in enumerate(batch_data):
# _, _, compoundstr, proteinstr, label = pair.strip().split()
pair = pair.strip().split()
drug_id,protein_id, compoundstr, proteinstr, label = pair[-5], pair[-4],pair[-3], pair[-2], pair[-1]
drug_ids.append(drug_id)
protein_ids.append(protein_id)
compoundint = torch.from_numpy(label_smiles(compoundstr, CHARISOSMISET,compound_max))
compound_new[i] = compoundint
proteinint = torch.from_numpy(label_sequence(proteinstr, CHARPROTSET,protein_max))
protein_new[i] = proteinint
label = float(label)
labels_new[i] = np.int(label)
# return (drug_ids,protein_ids, compound_new, protein_new, labels_new)
return (compound_new, protein_new, labels_new)