app.py
import numpy as np
import pandas as pd
import tensorflow as tf
import deepasr as asr


# build the CTC pipeline: feature extractor, alphabet, model, optimizer, decoder
def get_config(feature_type: str = 'spectrogram', multi_gpu: bool = False):
    # audio feature extractor
    features_extractor = asr.features.preprocess(feature_type=feature_type,
                                                 features_num=161,
                                                 samplerate=16000,
                                                 winlen=0.02,
                                                 winstep=0.025,
                                                 winfunc=np.hanning)
    # input label encoder (English alphabet)
    alphabet_en = asr.vocab.Alphabet(lang='en')
    # training model
    model = asr.model.get_deepasrnetwork1(
        input_dim=161,
        output_dim=29,
        is_mixed_precision=True
    )
    # model optimizer
    optimizer = tf.keras.optimizers.Adam(
        learning_rate=1e-4,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-8
    )
    # output label decoder
    decoder = asr.decoder.GreedyDecoder()
    # assemble the CTC pipeline
    pipeline = asr.pipeline.ctc_pipeline.CTCPipeline(
        alphabet=alphabet_en, features_extractor=features_extractor, model=model,
        optimizer=optimizer, decoder=decoder,
        sample_rate=16000, mono=True, multi_gpu=multi_gpu
    )
    return pipeline
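

# Usage sketch (illustrative only, not executed in this script): get_config can also
# be called with its default spectrogram features, and multi-GPU training can be
# enabled when more than one GPU is available, e.g.
#   pipeline = get_config(feature_type='spectrogram', multi_gpu=True)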


def run():
    # training data: CSV of audio file paths and matching transcripts
    train_data = pd.read_csv('train_data.csv')
    pipeline = get_config(feature_type='fbank', multi_gpu=False)
    # train the asr model
    history = pipeline.fit(train_dataset=train_data, batch_size=128, epochs=500)
    # history = pipeline.fit_generator(train_dataset=train_data, batch_size=32, epochs=500)
    pipeline.save('./checkpoints')
    return history
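

# Data layout sketch (an assumption, inferred from the columns test_model() reads
# below; the training CSV is assumed to use the same columns, and the rows shown
# here are placeholders, not real files):
#
#   path,transcripts
#   audio/sample_0001.wav,the quick brown fox
#   audio/sample_0002.wav,jumps over the lazy dog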


def test_model():
    # evaluate the saved pipeline on the first test sample
    test_data = pd.read_csv('test_data.csv')
    pipeline = asr.pipeline.load('checkpoints')
    print("Truth:", test_data['transcripts'].to_list()[0])
    print("Prediction:", pipeline.predict(test_data['path'].to_list()[0]))


if __name__ == "__main__":
    run()
    # test_model()