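"""predict.py (from a fork of raudipra/CRISPR-TTN).

Trains the TwoTowerModel on RNA/sgRNA sequence pairs with a constrained
triplet loss, checkpoints the best model, and saves the loss curves.

Usage:
    python predict.py [input_path] [model_path] [output_path]
"""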
import sys
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.layers.experimental.preprocessing import StringLookup
from model import TwoTowerModel
from metric import PairTripletAccuracy
from triplet import ConstrainedTripletLoss
from data_loader import DataLoader
# A plotting function you can reuse
def plot(history):
    # The history object contains the training and validation loss
    # for each epoch
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    # Get the number of epochs
    epochs = range(len(loss))

    _ = plt.figure()
    plt.title('Training and validation loss')
    plt.plot(epochs, loss, color='blue', label='Train')
    plt.plot(epochs, val_loss, color='orange', label='Val')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
@tf.function
def one_hot(sequence):
    # string_lookup is defined in __main__ before the datasets are mapped.
    sequence = string_lookup(sequence)
    C = tf.constant(5)  # depth: the 4 bases plus StringLookup's OOV id 0
    one_hot_matrix = tf.one_hot(
        sequence,
        C,
        on_value=1.0,
        off_value=0.0,
        axis=-1
    )
    # Drop the OOV column so unknown bases encode as all zeros.
    return one_hot_matrix[:, 1:]
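# Example of the encoding above, assuming StringLookup's defaults
# (OOV id 0, no mask token): "A"->1, "C"->2, "G"->3, "T"->4, "N"->0,
# so after dropping column 0:
#   one_hot(tf.constant(["A", "N"])) -> [[1., 0., 0., 0.],
#                                        [0., 0., 0., 0.]]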
def preprocess(feature, label):
    # feature[0] is the target RNA sequence, feature[1] the sgRNA.
    RNA_seq = one_hot(feature[0])
    sgRNA_seq = one_hot(feature[1])
    # Stack the two one-hot matrices along the sequence axis.
    return tf.concat([RNA_seq, sgRNA_seq], 0), label
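# With RNA_length=33 and gRNA_length=23 (see the model below), the concat
# yields a (56, 4) matrix per example, matching model.build((None, 56, 4)).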
if __name__ == "__main__":
    if len(sys.argv) < 4:
        print("usage: python predict.py [input_path] [model_path] [output_path]")
        sys.exit(1)

    input_path = sys.argv[1]
    model_path = sys.argv[2]
    output_path = sys.argv[3]

    data_loader = DataLoader(input_path, training_ratio=0.7)
    raw_train_ds, raw_val_ds = data_loader.load()
    # "N" is deliberately absent from the vocabulary: it falls to the OOV
    # id, whose column one_hot() drops, so N encodes as [0, 0, 0, 0].
    VOCAB = ["A", "C", "G", "T"]
    string_lookup = StringLookup(vocabulary=VOCAB)

    AUTOTUNE = tf.data.experimental.AUTOTUNE
    BATCH_SIZE = 256
    SHUFFLE_SIZE = 1000
    # Cache and shuffle the raw examples, then one-hot encode them.
    encoded_train_ds = raw_train_ds.cache().shuffle(SHUFFLE_SIZE)
    encoded_train_ds = encoded_train_ds.prefetch(buffer_size=AUTOTUNE)
    encoded_train_ds = encoded_train_ds.map(preprocess)
    encoded_val_ds = raw_val_ds.cache().map(preprocess)

    # Cache the encoded examples so the encoding runs only once, then batch.
    train_ds = encoded_train_ds.cache().batch(BATCH_SIZE)
    train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
    val_ds = encoded_val_ds.cache().batch(BATCH_SIZE)
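    # Each batch is now (features, labels) with features of shape
    # (BATCH_SIZE, 56, 4); the validation set is left unshuffled.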
    model = TwoTowerModel(RNA_length=33, gRNA_length=23)
    model.build((None, 56, 4))
    model.summary()  # summary() prints itself and returns None

    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=ConstrainedTripletLoss(),
        metrics=[PairTripletAccuracy()]
    )

    stop_early = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10)
    # 'val_custom_accuracy' assumes PairTripletAccuracy reports itself
    # under the name 'custom_accuracy'.
    save_best = ModelCheckpoint(model_path, monitor='val_custom_accuracy',
                                mode='max', verbose=1, save_best_only=True)
    history = model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=30,
        callbacks=[stop_early, save_best]
    )

    plot(history)
    # output_path is otherwise unused; assume it is the destination for
    # the loss curves.
    plt.savefig(output_path)
    plt.show()
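    # A minimal sketch of reloading the best checkpoint later (the
    # custom_objects names are assumptions based on the imports above):
    #   best = load_model(model_path, custom_objects={
    #       "TwoTowerModel": TwoTowerModel,
    #       "ConstrainedTripletLoss": ConstrainedTripletLoss,
    #       "PairTripletAccuracy": PairTripletAccuracy,
    #   })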