-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
62 lines (44 loc) · 1.76 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import plac
import numpy as np
from keras import backend as K
from keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from music_tagger_cnn import MusicTaggerCNN, SmallCNN, SmallestCNN
from music_tagger_crnn import MusicTaggerCRNN
import data
# Build every model below with channels-first ("th"-style) tensor layout.
# set_image_data_format('channels_first') is the Keras 2 spelling of the same
# setting; set_image_dim_ordering("th") is the legacy alias for it and may be
# missing from newer Keras versions.
K.set_image_data_format('channels_first')
def load_data(path):
    """Load mel-spectrogram samples from a class-per-subdirectory tree.

    Each immediate subdirectory of ``path`` is one class. Every ``*.npy``
    file inside it is loaded with ``np.load`` as one sample of that class.
    Class ids are assigned in sorted subdirectory order, and files are read
    in sorted order. This keeps the label mapping and the sample order (and
    therefore any seeded train/test split done by the caller) reproducible
    across machines and filesystems.

    Args:
        path: Root directory, as a ``pathlib.Path`` or a path string.

    Returns:
        A tuple ``(x, y, class_names)``:
        - ``x``: ``np.array`` of the loaded arrays.
        - ``y``: the one-hot label matrix produced by ``to_categorical``.
        - ``class_names``: the subdirectory names, indexed by class id.

    Raises:
        ValueError: if ``path`` has no subdirectories, or the subdirectories
            contain no ``*.npy`` files.
    """
    from pathlib import Path  # local import: lets callers pass a plain str

    root = Path(path)
    # iterdir() order is filesystem-dependent; sorting makes class ids stable.
    class_dirs = sorted(el for el in root.iterdir() if el.is_dir())
    if not class_dirs:
        raise ValueError('no class subdirectories found in {}'.format(root))

    x = []
    y = []
    class_names = []
    for class_id, subfolder in enumerate(class_dirs):
        class_names.append(subfolder.name)
        # glob() order is also unspecified; sort for a reproducible sample order.
        for melgram_path in sorted(subfolder.glob('*.npy')):
            x.append(np.load(melgram_path))
            y.append(class_id)
    if not x:
        raise ValueError('no *.npy files found under {}'.format(root))

    y = to_categorical(y, len(class_names))
    return np.array(x), np.array(y), class_names
def main(net_type, epochs=10):
    """Train one of the music-tagger networks on the saved mel-spectrograms.

    Args:
        net_type: Architecture to build: 'cnn', 'small_cnn', 'smallest_cnn'
            or 'crnn'.
        epochs: Number of training epochs passed to ``model.fit``.

    Raises:
        ValueError: if ``net_type`` is not one of the names above. This is
            raised after the data has been loaded and the class names printed.
    """
    samples, labels, class_names = load_data(data.MELGRAM_LOCATION)
    print(class_names)

    # All architectures take the same constructor arguments, so a lookup
    # table replaces an if/elif ladder.
    builders = {
        'cnn': MusicTaggerCNN,
        'small_cnn': SmallCNN,
        'smallest_cnn': SmallestCNN,
        'crnn': MusicTaggerCRNN,
    }
    builder = builders.get(net_type)
    if builder is None:
        raise ValueError(net_type)
    model = builder(data.N_FRAMES, data.N_MELS, len(class_names))

    model.summary()
    model.compile(optimizer='rmsprop', loss='categorical_crossentropy',
                  metrics=['accuracy'])

    # TODO change batch size (no batch_size is passed to model.fit yet)
    X_train, X_test, y_train, y_test = train_test_split(
        samples, labels, test_size=0.2, random_state=42)
    model.fit(X_train, y_train, epochs=epochs,
              validation_data=(X_test, y_test))

    # NOTE(review): ':' is not a legal filename character on Windows; kept
    # as-is in case other scripts load the model by this exact name.
    model.save('music_{}_epochs:{}.h5'.format(net_type, epochs))
if __name__ == '__main__':
    # Script entry point: plac maps command-line arguments onto main()'s
    # parameters (net_type, then the optional epochs).
    plac.call(main)