-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
56 lines (48 loc) · 1.84 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' #To suppress tensorflow warnings
import numpy as np
import tensorflow as tf
import tensorflow_io as tfio
import train_speech_id_model
def main():
    """Compare every pair of .mp3 recordings in the current directory.

    Each recording is loaded, resampled to ``target_rate`` when its own
    sample rate differs, and passed through the speech embedding model.
    For every unordered pair of recordings the Euclidean distance between
    the two embeddings is printed together with a verdict: pairs closer
    than ``threshold`` are reported as the same person, all others as
    different people.

    The model is loaded from the ``speech-id-model-110`` SavedModel
    directory when it exists; otherwise it is rebuilt from the
    ``cp-0110.ckpt`` checkpoint and saved to that directory so that later
    runs can load it directly.

    Returns:
        None. All results are written to standard output.
    """
    # Function-scope import keeps this function self-contained without
    # touching the file-level import block.
    from itertools import combinations

    threshold = 0.83      # max embedding distance still reported as same person
    target_rate = 48000   # sample rate (Hz) every recording is resampled to

    # Require the dot so names like 'weirdmp3' are not picked up, skip
    # directories, and sort so the output order is reproducible
    # (os.listdir() returns entries in arbitrary order).
    audio_files = sorted(
        name for name in os.listdir()
        if name.endswith('.mp3') and os.path.isfile(name)
    )
    print(f'Comparing {audio_files}')

    # A comparison needs two recordings; bail out before paying for the
    # model load when there is nothing to compare.
    if len(audio_files) < 2:
        print('Need at least two .mp3 files to compare; nothing to do.')
        return

    if os.path.isfile('speech-id-model-110/saved_model.pb'):
        model = tf.keras.models.load_model('speech-id-model-110')
    else:
        # First run: rebuild the architecture, restore the checkpoint
        # weights, and cache the result as a SavedModel.
        model = train_speech_id_model.BaseSpeechEmbeddingModel()
        model.load_weights('speech-id-model-110/cp-0110.ckpt')
        model.save('speech-id-model-110')

    audio_embeddings = []
    for path in audio_files:
        cur_data = tfio.audio.AudioIOTensor(path)
        print(f'Processing {path} with sample rate of {cur_data.rate}')
        # Keep only column 0 of the decoded tensor (presumably the first
        # audio channel -- confirm against the tensorflow-io docs).
        audio_data = cur_data.to_tensor()[:, 0]
        if cur_data.rate != target_rate:
            print(f'Sampling rate is not {target_rate}. Resampling...')
            audio_data = tfio.audio.resample(
                audio_data,
                tf.cast(cur_data.rate, tf.int64),
                tf.cast(target_rate, tf.int64),
            )
        # Add a batch dimension of size 1, then take the single result row.
        cur_emb = model.predict(tf.expand_dims(audio_data, axis=0))[0]
        audio_embeddings.append(cur_emb)

    # combinations() yields each unordered pair once, in the same order as
    # a nested "for p / for q in range(p + 1, ...)" loop.
    labelled = zip(audio_files, audio_embeddings)
    for (f1, emb1), (f2, emb2) in combinations(labelled, 2):
        distance = np.linalg.norm(emb1 - emb2)
        if distance < threshold:
            conclusion = 'Same person:'
        else:
            conclusion = 'Different people:'
        print(f'{f1} and {f2}: {conclusion} {distance}')
# Run the comparison only when this file is executed as a script, so that
# importing the module does not trigger model loading or audio processing.
if __name__ == '__main__':
    main()