"""Test Google Cloud ASR offerings on the Speech in Noise (SPIN) test."""
import datetime
import json
import re
from typing import List, Dict, Optional, Set, Tuple, Union
from absl import app
from absl import flags
import dataclasses
import matplotlib.pyplot as plt
from scipy import signal
from scipy.io import wavfile
from scipy.optimize import curve_fit
import numpy as np
import os
import fsspec
from google.cloud import speech_v2
from google.api_core.client_options import ClientOptions
from google.cloud.speech_v2 import SpeechClient
from google.cloud.speech_v2.types import cloud_speech
# Import the Google recognizer response type
from google.cloud.speech_v1.types.cloud_speech import RecognizeResponse
# Supported Languages:
# https://cloud.google.com/speech-to-text/v2/docs/speech-to-text-supported-languages
# Import the Whisper Recognizer
import whisper
import subprocess
from whisper.normalizers import EnglishTextNormalizer
# https://github.com/openai/whisper/discussions/734
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
####### Utilities to describe recognition results ########
@dataclasses.dataclass
class RecogResult:
word: str
start_time: float
end_time: float
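# For example, RecogResult('white', 0.52, 0.91) records the word 'white'
# spanning 0.52s to 0.91s within one 60-second QuickSIN recording.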
################### Speech Recognition Class (easier API) ######################
class RecognitionEngine(object):
def CreateSpeechClient(self, gcp_project, model='default_long'):
pass
def CreateRecognizer(self, with_timings=False, locale: str = 'en-US'):
pass
class GoogleRecognitionEngine(RecognitionEngine):
"""A class that provides a nicer interface to Google's Cloud
text-to-speech API.
Here are some useful links:
https://cloud.google.com/speech-to-text/docs/speech-to-text-supported-languages
"""
def __init__(self):
self._client = None
self._parent = None
def CreateSpeechClient(self,
gcp_project,
model='default_long',
):
"""Acquires the appropriate authentication and creates a Cloud Speech stub.
The model name is needed because we connect to a different server if the
model is 'chirp'.
    The resulting Cloud Speech client is stored in self._client.
    """
self._model = model
self._project = gcp_project
self._spoken_punct = False
self._auto_punct = False
if model == 'chirp':
chirp_endpoint = 'us-central1-speech.googleapis.com'
client_options = ClientOptions(api_endpoint=chirp_endpoint)
self._location = 'us-central1'
else:
client_options = ClientOptions()
self._location = 'global'
self._client = SpeechClient(client_options=client_options)
  def ListModels(self, gcp_project: str):
    if self._client is None:
      self.CreateSpeechClient(gcp_project)
    parent = f'projects/{self._project}/locations/{self._location}'
    # The v2 client exposes recognizers rather than a model list, so reuse
    # the recognizer listing here.
    request = speech_v2.ListRecognizersRequest(parent=parent)
    return self._client.list_recognizers(request)
def ListRecognizers(self, gcp_project: str):
if self._client is None:
self.CreateSpeechClient(gcp_project)
parent = f'projects/{self._project}/locations/{self._location}'
request = speech_v2.ListRecognizersRequest(parent=parent)
# print(f'ListRecognizers request is: {request}')
return self._client.list_recognizers(request)
def CreateRecognizer(self,
with_timings=False,
locale: str = 'en-US',
# gcp_project: str,
# recognizer_id: str,
# debug=False
):
# https://cloud.google.com/speech-to-text/v2/docs/medical-models
if self._model == 'medical_conversation':
self._spoken_punct = False
self._auto_punct = True
elif self._model == 'medical_dictation':
self._spoken_punct = True
self._auto_punct = True
else:
self._spoken_punct = False
self._auto_punct = False
self._recognizer_config = cloud_speech.RecognitionConfig(
auto_decoding_config=cloud_speech.AutoDetectDecodingConfig(),
language_codes=[locale],
model=self._model,
features = speech_v2.RecognitionFeatures(
enable_word_time_offsets = with_timings,
enable_automatic_punctuation = self._auto_punct,
enable_spoken_punctuation = self._spoken_punct,
),
)
def RecognizeFile(self,
audio_file_path: str,
with_timings=False,
debug=False) -> cloud_speech.RecognizeResponse:
"""Recognize the speech from a file.
Returns:
https://cloud.google.com/python/docs/reference/speech/latest/google.cloud.speech_v1.types.RecognizeResponse
Note: Unless the file ends in .wav, the file is read in, and the entire
contents, including the binary header, are passed to the recognizer as a
16kHz audio waveform."""
if audio_file_path.endswith('.wav'):
with fsspec.open(audio_file_path, 'rb') as fp:
audio_fs, audio_data = wavfile.read(fp)
return self.RecognizeWaveform(audio_data, audio_fs,
with_timings=with_timings)
recognizer_name = (f'projects/{self._project}/locations/'
f'{self._location}/recognizers/_')
# Create the request we'd like to send
request = cloud_speech.RecognizeRequest(
recognizer = recognizer_name,
config = self._recognizer_config,
content = self.ReadAudioFile(audio_file_path)
)
# Send the request
if debug:
print(request)
response = self._client.recognize(request)
return response
  def RecognizeWaveform(self,
                        waveform: Union[bytes, np.ndarray],
                        sample_rate: int = 16000,
                        with_timings=False,
                        debug=False) -> cloud_speech.RecognizeResponse:
"""Recognize the speech from a waveform."""
if isinstance(waveform, np.ndarray):
waveform = waveform.astype(np.int16).tobytes()
recognizer_name = (f'projects/{self._project}/locations/'
f'{self._location}/recognizers/_')
# Create the request we'd like to send
self._recognizer_config = cloud_speech.RecognitionConfig(
explicit_decoding_config = cloud_speech.ExplicitDecodingConfig(
# Change these based on the encoding of the audio
# See the encoding documentation on how to do this.
# https://cloud.google.com/speech-to-text/v2/docs/encoding
encoding = 'LINEAR16',
sample_rate_hertz = sample_rate,
audio_channel_count = 1,
),
# auto_decoding_config=cloud_speech.AutoDetectDecodingConfig(),
language_codes=['en-US'],
model=self._model,
features = speech_v2.RecognitionFeatures(
enable_word_time_offsets = with_timings,
enable_automatic_punctuation = self._auto_punct,
enable_spoken_punctuation = self._spoken_punct,
),
)
request = cloud_speech.RecognizeRequest(
recognizer = recognizer_name,
config = self._recognizer_config,
content = waveform
)
if debug:
print(request)
# Send the request
response = self._client.recognize(request)
return response
def parse_transcript(self, response:
cloud_speech.RecognizeResponse) -> List[RecogResult]:
"""Parse the results from the Cloud ASR engine and return a simple list
of words and times. This is for the entire (60s) utterance."""
words = []
    for a_result in response.results:
      # Sometimes a result is missing its alternatives list (for reasons I
      # don't understand), so skip those results.
      if not a_result.alternatives:
        continue
      end_time = 0.0  # In case the best alternative has no timed words.
      for word in a_result.alternatives[0].words:
        start_time = parse_time(word.start_offset)
        end_time = parse_time(word.end_offset)
        words.append(RecogResult(word.word.lower(), start_time, end_time))
      # Mark the end of each result with a sentence-final period.
      words.append(RecogResult('.', end_time, end_time))
return words
def ReadAudioFile(self, audio_file_path: str):
# if audio_file_path[0] != '/':
# PREFIX = '/google_src/files/head/depot/'
# audio_file_path = os.path.join(PREFIX, audio_file_path)
with fsspec.open(audio_file_path, 'rb') as audio_file:
audio_data = audio_file.read()
return audio_data
def parse_time(time_proto:datetime.timedelta) -> float:
# return time_proto.seconds + time_proto.nanos/1e9
return time_proto.total_seconds()
def print_all_sentences(results: cloud_speech.RecognizeResponse):
  for r in results.results:
if r.alternatives:
print(r.alternatives[0].transcript)
else:
print('No alternatives')
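# A minimal usage sketch for GoogleRecognitionEngine, mirroring how the
# engine is driven elsewhere in this file. The project name below is a
# placeholder, not a real GCP project.
def example_google_recognition(gcp_project: str = 'my-gcp-project'):
  engine = GoogleRecognitionEngine()
  engine.CreateSpeechClient(gcp_project, model='latest_long')
  engine.CreateRecognizer(with_timings=True)
  response = engine.RecognizeFile('QuickSIN22/Clean List 1.wav')
  for recog in engine.parse_transcript(response):
    print(f'{recog.word}: {recog.start_time}s-{recog.end_time}s')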
####### Code to talk to the Whisper Recognizer and reformat its results ########
def test_whisper():
print('Just testing Whisper Engine')
ground_truth = load_ground_truth(FLAGS.ground_truth_cache)
whisper_engine = make_recognizer_engine('whisper.base.en')
spin_file = find_all_spin_files(FLAGS.audio_dir, 'Babble List ')[0]
whisper_dict = whisper_engine.RecognizeFile(spin_file)
word_list = whisper_engine.parse_transcript(whisper_dict)
  score_all_tests(spin_snrs, ground_truth[:1], [word_list,],
                  good_lists=[1], debug=True)
class WhisperRecognitionEngine(RecognitionEngine):
def __init__(self, model_type='small.en'):
self.model = whisper.load_model(model_type)
def RecognizeFile(self, path: str, with_timings=False, debug=False) -> Dict:
"""Recognize the speech from a file.
"""
whisper_result = self.model.transcribe(path)
return whisper_result
def parse_transcript(self, whisper_dict) -> List[RecogResult]:
results = []
for seg in whisper_dict['segments']:
b = float(seg['start'])
e = float(seg['end'])
      for w in seg['text'].split():
results.append(RecogResult(w, b, e))
return results
def make_recognizer_engine(model_type: str) -> RecognitionEngine:
if model_type.startswith('whisper'):
return WhisperRecognitionEngine(model_type.replace('whisper.', ''))
return GoogleRecognitionEngine()
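# Dispatch sketch: make_recognizer_engine('whisper.base.en') loads a local
# Whisper checkpoint and is ready to use, while make_recognizer_engine('chirp')
# returns a GoogleRecognitionEngine that still needs CreateSpeechClient() and
# CreateRecognizer() calls before recognizing anything.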
####### Utilities to prepare original SPIN waveforms ########
def generate_ffmpeg_cmds():
"""Generate the FFMPEG commands to downsample and rename the
QuickSIN files. The Google drive data from Matt has these files:
* 34 Sep List 11.aif - Stereo utterances: clean sentences on the left,
constant amplitude babble noise on the right
* 34 Sep List 11_sentence.wav - Mono clean sentences
* 34 Sep List 11_babble.wav - Mono babble
* List 11.aif - Mono mixed test sentences, with the SNR stepping
down after each sentence.
"""
for i in range(1, 13):
input_name = f'{23+i} Sep List {i}_sentence.wav'
output_name = f'QuickSIN22/Clean List {i}.wav'
print(f'ffmpeg -i "{input_name}" -ar 22050 "{output_name}"')
input_name = f'List {i}.aif'
output_name = f'QuickSIN22/Babble List {i}.wav'
print(f'ffmpeg -i "{input_name}" -ar 22050 "{output_name}"')
print()
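# For i=1 the loop above prints, for example:
#   ffmpeg -i "24 Sep List 1_sentence.wav" -ar 22050 "QuickSIN22/Clean List 1.wav"
#   ffmpeg -i "List 1.aif" -ar 22050 "QuickSIN22/Babble List 1.wav"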
################## Organize SPIN recognition results ######################
# A list of lists. Each (final) list is a list of recognition results (words
# and times). Then a list of these "sentence" lists.
SpinFileTranscripts = List[List[RecogResult]]
def recognize_all_spin(all_wavs: List[str],
asr_engine: RecognitionEngine,
debug=False) -> SpinFileTranscripts:
"""Recognize some SPiN sentences using the specified ASR engine.
Return a list of the transcription results. Each recognition result is
a list of alternatives, all in RecogResult format. This is used for both
clean and noisy utterances.
Args:
all_wavs: List of SPIN wave file names
asr_engine: A recognition object to do the calculations
debug: Whether to generate debugging messages.
Returns:
A list (for each SPIN list) of lists (for each sentence) of recognition
results.
"""
all_results = []
for f in all_wavs:
if 'Calibration' in f:
continue
pretty_file_name = os.path.basename(f)
if debug:
print('Recognizing', pretty_file_name)
resp = asr_engine.RecognizeFile(f, with_timings=True, debug=debug)
if debug:
print(f'{pretty_file_name}:',)
for result in resp.results:
if result.alternatives:
print(f' {result.alternatives[0].transcript}')
else:
print('. ** Empty ASR Result **')
recog_results = asr_engine.parse_transcript(resp)
all_results.append(recog_results)
return all_results
def find_sentence_boundaries(
    spin_truth_names: List[str], # File names with clean speech.
    sentence_boundary_graph: str = '') -> Tuple[List[float], np.ndarray]:
  """Figure out the inter-sentence boundaries shared by the sentences in
  all the lists. Do this by summing the absolute value of each waveform,
  filtering this to get an envelope, and then looking for the minima.
  Return a list of times (in seconds) marking the midpoints between
  sentences, along with the summed absolute-value waveform.
  """
# Figure out the maximum length
max_len = 0
for i in range(12):
with fsspec.open(spin_truth_names[i], 'rb') as fp:
audio_fs, audio_data = wavfile.read(fp)
max_len = max(max_len, len(audio_data))
# Now sum the absolute value of each of the 12 waveforms.
all_audio = np.zeros(max_len, float)
for i in range(12):
with fsspec.open(spin_truth_names[i], 'rb') as fp:
_, audio_data = wavfile.read(fp)
all_audio[:len(audio_data)] = (all_audio[:len(audio_data)] +
np.abs(audio_data))
  # Now filter this signal to smooth it.
b, a = signal.butter(4, 0.00005)
envelope = signal.filtfilt(b, a, all_audio, padlen=150)
envelope = signal.filtfilt(b, a, envelope, padlen=150)
def find_min(y, start, stop):
start_sample = int(start)
end_sample = int(stop)
i = np.argmin(y[start_sample:end_sample]) + start_sample
return float(i)
# Look for the minimum in each approximate range.
splits = np.array([0.2, 0.3, 0.5, 0.7, 0.9, 1.1])*1e6 # This is in samples.
breaks = [0]
for i in range(5):
breaks.append(find_min(envelope, splits[i], splits[i+1])/audio_fs)
breaks.append(max_len/audio_fs)
# Plot the results
if sentence_boundary_graph:
plt.clf()
plt.plot(np.arange(len(all_audio))/float(audio_fs), all_audio)
plt.xlabel('Time (s)')
plt.ylabel('Average Audio Level')
plt.title('Sentence Boundaries from Average Amplitude')
current_axis = plt.axis()
for b in breaks:
plt.plot([b, b], current_axis[2:], '--')
plt.savefig(sentence_boundary_graph)
return breaks, all_audio
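# Note on the envelope filter above: signal.butter(4, 0.00005) takes its
# cutoff as a fraction of the Nyquist rate, so at the 22050 Hz sample rate
# used here the cutoff is roughly 0.00005 * 11025 ~= 0.55 Hz. That is slow
# enough to smooth over individual words while preserving the dips between
# sentences.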
######################## QuickSIN Ground Truth ################################
# Pages 96-97 of this PhD thesis:
# Suzanne E. Sklaney, Binaural sound field presentation of the QuickSIN:
# Equivalency across lists and signal-to-noise ratios.
# https://etda.libraries.psu.edu/files/final_submissions/5788
key_word_list = """
L 0 S 0 white silk jacket any shoes
L 0 S 1 child crawled into dense grass
L 0 S 2 Footprints showed path took beach
L 0 S 3 event/vent near edge fresh air
L 0 S 4 band Steel 3/three inches/in wide
L 0 S 5 weight package seen high scale
L 1 S 0 tear/Tara/tera thin sheet yellow pad
L 1 S 1 cruise Waters Sleek yacht fun
L 1 S 2 streak color down left Edge
L 1 S 3 done before boy see it
L 1 S 4 Crouch before jump miss mark
L 1 S 5 square peg settle round hole
L 2 S 0 pitch straw through door stable
L 2 S 1 sink thing which pile dishes
L 2 S 2 post no bills office wall
L 2 S 3 dimes showered down all sides
L 2 S 4 pick card slip under pack/pact
L 2 S 5 store jammed before sale start
L 3 S 0 sense smell better than touch
L 3 S 1 picked up dice second roll
L 3 S 2 drop ashes worn/Warren Old rug
L 3 S 3 couch cover Hall drapes blue
L 3 S 4 stems Tall Glasses cracked broke
L 3 S 5 cleats sank deeply soft turf
L 4 S 0 have better than wait Hope
L 4 S 1 screen before fire kept Sparks
L 4 S 2 thick glasses helped read print
L 4 S 3 chair looked strong no bottom
L 4 S 4 told wild Tales/tails frighten him
L 4 S 5 force equal would move Earth
L 5 S 0 leaf drifts along slow spin
L 5 S 1 pencil cut sharp both ends
L 5 S 2 down road way grain farmer
L 5 S 3 best method fix place clips
L 5 S 4 if Mumble your speech lost
L 5 S 5 toad Frog hard tell apart
L 6 S 0 kite dipped swayed/suede stayed aloft
L 6 S 1 beatle/beetle drowned hot June sun/son
L 6 S 2 theft Pearl pin Kept Secret
L 6 S 3 wide grin earned many friends
L 6 S 4 hurdle pit aid long Pole
L 6 S 5 Peep under tent see Clown
L 7 S 0 sun came light Eastern sky
L 7 S 1 stale smell old beer lingers
L 7 S 2 desk firm on shaky floor
L 7 S 3 list names carved around base
L 7 S 4 news struct/struck out Restless Minds
L 7 S 5 Sand drifts over sill house
L 8 S 0 take shelter tent keep still
L 8 S 1 Little Tales/tails they tell false
L 8 S 2 press pedal with left foot
L 8 S 3 black trunk fell from Landing
L 8 S 4 cheap clothes flashy don't/dont last
L 8 S 5 night alarm roused/roust deep sleep
L 9 S 0 dots light betrayed black cat
L 9 S 1 put chart mantle Tack down
L 9 S 2 steady drip worse drenching rain
L 9 S 3 flat pack less luggage space
L 9 S 4 gloss top made unfit read
L 9 S 5 Seven Seals stamped great sheets
L10 S 0 marsh freeze when cold enough
L10 S 1 gray mare walked before colt
L10 S 2 bottles hold four kinds rum
L10 S 3 wheeled/wheled/wield bike past winding road
L10 S 4 throw used paper cup plate
L10 S 5 wall phone ring loud often
L11 S 0 hinge door creaked old age
L11 S 1 bright lanterns Gay dark lawn
L11 S 2 offered proof form large chart
L11 S 3 their eyelids droop want sleep
L11 S 4 many ways do these things
L11 S 5 we like see clear weather/whether
""".split('\n')
def word_alternatives(words: str,
homonyms_dict: Dict[str, Set[str]]) -> Set[str]:
"""Convert a string with words separated by '/' into a set."""
all_words = words.strip().split('/')
base_word = all_words[0]
if base_word in homonyms_dict:
return set(all_words) | homonyms_dict[base_word]
return set(all_words)
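# For example, once the homonyms below are ingested,
#   word_alternatives('tails/tales', homonym_list)
# returns {'tails', 'tales'}: the '/'-separated alternatives, merged with any
# homonyms registered for the base word 'tails'.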
homonyms = """
# Add word equivalences here. The nominal correct word, from the list above,
# should be listed first, then the alternatives.
# Homonyms
tails/tales
four/4/for
mare/maire
pedal/petal
wheeled/wield
sun/son
marsh/marsue
their/there
white/whitesilk
silk/whitesilk
roll/role
drowned/dround
yacht/yaught/yach
hall/haul
# Close-enough words.
# None so far: we count it as an error if even one phoneme is wrong.
"""
def make_homonyms_dictionary(*equivalence_lists: str) -> Dict[str, Set[str]]:
  """Convert a set of speech-recognition equivalences, specified as text, into
  a dictionary of sets of equivalent words. Each equivalence is specified as
  words on a line, separated by '/'. Lines that start with '#' are ignored so
  that the choices can be documented.
  """
  result_dict = {}
  for equivalence_text in equivalence_lists:
    lines = equivalence_text.split('\n')
    equivalence_lines = [line.strip().split('/') for line in lines
                         if line.strip() and line.strip()[0] != '#']
    for a_set in equivalence_lines:
      w = a_set[0]  # The base term.
      if w in result_dict:
        raise ValueError(f'Found duplicate key {w}')
      result_dict[w] = set(a_set[1:])
return result_dict
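# A small worked example, in the same format as the homonyms string above:
#   make_homonyms_dictionary('tails/tales\nsun/son')
# returns {'tails': {'tales'}, 'sun': {'son'}}.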
def ingest_quicksin_truth(
    word_list: List[str],
    homonym_dict: Dict[str, Set[str]]) -> Dict[Tuple[int, int],
                                               List[Set[str]]]:
  """Convert the lines of the big string above into the sets of key words
  (and alternatives) that describe the expected answers from a SPIN test.
  For each line, create a list of test words, where each test word is stored
  as a set of acceptable alternatives, and enter it into a dictionary keyed
  by (list number, sentence number).
  """
keyword_dict = {}
for line in word_list:
line = line.strip().lower()
if not line: continue
list_number = int(line[1:3])
sentence_number = int(line[5:7])
key_words = line[7:].split(' ')
key_words = [w for w in key_words if w]
key_list = [word_alternatives(w, homonym_dict) for w in key_words]
    if len(key_list) != 5:
      print(f'Wrong number of key words in L{list_number} S{sentence_number}:',
            key_list)
keyword_dict[list_number, sentence_number] = key_list
return keyword_dict
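# For instance, the line 'l 0 s 0 white silk jacket any shoes' (after
# lower-casing) yields the entry keyed by (0, 0): five alternative sets such
# as [{'white', 'whitesilk'}, {'silk', 'whitesilk'}, {'jacket'}, {'any'},
# {'shoes'}] once the homonym dictionary has been folded in.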
homonym_list = make_homonyms_dictionary(homonyms)
all_keyword_dict = ingest_quicksin_truth(key_word_list, homonym_list)
######## Recognize the SPIN waveforms and calculate all word timings ##########
@dataclasses.dataclass
class SpinSentence:
"""A structure that describes one SPiN sentence, with the transcript,
individual words, the sentence start and end time, and the SNR.
There are six SPiN sentences per list, one per SNR.
"""
sentence_words: List[str]
true_word_list: List[Set[str]] # List of words and their alternatives
# words: list[str]
start_time: float
end_time: float
snr: float # This sentence's test SNR
# Organize the clean speech transcripts. Each 60s waveform becomes a list of
# recognized sentences. Return a list of lists of sentences.
spin_snrs = (25, 20, 15, 10, 5, 0)
def format_quicksin_truth(
spin_transcripts: SpinFileTranscripts, # List of List of RecogResults
sentence_breaks: List[float], # Times in seconds
    snr_list: Tuple[float, ...] = spin_snrs) -> List[List[SpinSentence]]:
"""Parse the recognition results and produce a List (of sentences at different
SNRs). Return a list of 12 SPIN lists, each list containing the 6 SPIN
sentences at the different SNRs.
"""
assert len(spin_transcripts) > 0
# assert len(sentence_breaks) == 7 # Nominally 7 except when testing
# assert len(snr_list) == 6 # Nominally 6 except when testing
spin_results = []
# Iterate through the lists (each list contains 6 different sentences)
print('Sentence breaks are at:', sentence_breaks)
for list_number, clean_transcript in enumerate(spin_transcripts):
sentences = []
for snr_number, snr in enumerate(snr_list):
sentence_start_time = float(sentence_breaks[snr_number])
sentence_end_time = float(sentence_breaks[snr_number+1])
sentence_words = [w for w in clean_transcript
if (w.start_time > sentence_start_time and
w.end_time < sentence_end_time)]
assert len(sentence_words) > 0, (f'No words found for list {list_number},'
f' snr #{snr_number} between '
f'{sentence_start_time}s and '
f'{sentence_end_time}s.')
recognized_words = [w.word for w in sentence_words]
      sentence = SpinSentence(recognized_words,
                              all_keyword_dict[list_number, snr_number],
                              min(w.start_time for w in sentence_words),
                              # Use the words' end times for the sentence end.
                              max(w.end_time for w in sentence_words),
                              snr
                              )
sentences.append(sentence)
spin_results.append(sentences)
return spin_results
def print_spin_ground_truth(truth: List[List[SpinSentence]]):
for spin_i, spin_list in enumerate(truth):
print(f'\nQuickSIN list {spin_i}')
for sentence in spin_list:
print(f'SNR {sentence.snr} '
f'from {sentence.start_time}s to {sentence.end_time}s:',
' '.join(sentence.sentence_words)
)
def save_ground_truth(truth: List[List[SpinSentence]], filename: str):
"""Save the QuickSIN ground truth into a JSON file so we don't have
to compute it again. The ground truth starts as a SpinSentence, but is
saved (and restored) as a dictionary."""
class GoogleSinEncoder(json.JSONEncoder):
def default(self, o):
if dataclasses.is_dataclass(o):
return dataclasses.asdict(o)
elif isinstance(o, set):
return list(o)
return super().default(o)
saved_data = {
'ground_truth': truth,
'time': str(datetime.datetime.now()),
}
with fsspec.open(filename, 'w') as fp:
json.dump(saved_data, fp, cls=GoogleSinEncoder)
def load_ground_truth(filename: str) -> List[List[SpinSentence]]:
"""Load the precomputed QuickSIN ground truth from a file.
Args:
The file from which to read the cached ground truth.
Return:
A list of 12 SPIN lists, each list containing the 6 SPIN
sentences at the different SNRs. The data is a dictionary on disk, but
is converted to a SpinSentence by this routine.
"""
with fsspec.open(filename, 'r') as fp:
saved_data = json.load(fp)
  if isinstance(saved_data, dict):
    truth = saved_data['ground_truth']
    #pylint: disable=inconsistent-quotes
    print(f'Reloading ground truth saved at {saved_data["time"]}')
  else:
    truth = saved_data  # Old format file, without a timestamp.
assert isinstance(truth, list)
for i in range(len(truth)): # Nominally 12, except during testing
assert isinstance(truth[i], list)
for s in range(len(truth[i])): # Nominally 6, except during testing
truth[i][s] = SpinSentence(**truth[i][s])
truth[i][s].true_word_list = [set(word_list) for word_list
in truth[i][s].true_word_list]
return truth
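# Round-trip sketch (with a hypothetical cache path):
#   save_ground_truth(truth, 'quicksin_truth.json')
#   truth = load_ground_truth('quicksin_truth.json')
# reconstructs the SpinSentence objects, including the per-word sets of
# alternatives that JSON stored as plain lists.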
def print_quicksin_ground_truth(d: List[List[SpinSentence]]):
print('Ground Truth with all Synonyms:')
for list_number in range(12):
for sentence_number in range(6):
print(f'L{list_number} S{sentence_number}:', end='')
sent = d[list_number][sentence_number]
print(f' [{sent.start_time}-{sent.end_time}] ', end='')
for syn_set in sent.true_word_list:
syn_list = list(syn_set)
print(f'{"/".join(syn_list)} ', end='')
print()
# Currently unused (XX-prefixed) helpers for sorting file names by their
# embedded list number.
XXnumber_re = re.compile(r' (\d+)\.wav')
def XXsort_by_list_number(s: str) -> int:
  m = XXnumber_re.search(s)
  assert m, f'Could not find list number in {s}'
  return int(m[1])
def find_all_spin_files(audio_dir: str, prefix: str):
"""Return the list of SPIN test files that match the specified prefix.
Return them in order so we can use the ones we care about.
Args:
audio_dir: Where to find the audio files in .wav format
prefix: The file name prefix for the subset we care about.
Returns:
A list of filenames with the given prefix, in numerical order.
"""
spin_file_names = [os.path.join(audio_dir,
f'{prefix}{i+1}.wav') for i in range(12)]
for f in spin_file_names:
assert os.path.exists(f), f'{f} not found in find_all_spin_files'
return spin_file_names
def compute_quicksin_truth(
wav_dir: str,
project_id: str,
sentence_breaks: Optional[List[float]] = None,
    snr_list: Tuple[float, ...] = spin_snrs,
sentence_boundary_graph: str = '') -> List[List[SpinSentence]]:
"""Create the ground truth for a SPIN test.
Process all the clean speech files to figure out the start and stop of each
sentence. Combine with the keyword list to create a list (by QuickSin list)
of lists of sentences (one sentence per test SNR).
"""
spin_truth_names = find_all_spin_files(wav_dir, 'Clean List ')
print(f'Found {len(spin_truth_names)} QuickSIN lists to process.')
if sentence_breaks is None:
print('Finding sentence boundaries...')
sentence_breaks, _ = find_sentence_boundaries(spin_truth_names,
sentence_boundary_graph)
print('Sentence breaks are:', sentence_breaks)
model = 'latest_long'
print(f'Transcribing the QuickSIN WAV files with {model} model....')
asr_engine = make_recognizer_engine(model)
asr_engine.CreateSpeechClient(project_id, model)
asr_engine.CreateRecognizer(with_timings=True)
true_transcripts = recognize_all_spin(spin_truth_names, asr_engine)
print('True transcripts are:')
for l in range(len(true_transcripts)):
for s in range(len(true_transcripts[l])):
print(f'List {l}, Sentence {s}:', true_transcripts[l][s])
print('Formatting the QuickSIN Ground Truth....')
spin_ground_truths = format_quicksin_truth(true_transcripts,
sentence_breaks,
snr_list)
return spin_ground_truths
################ SCORE ALL MODELS IN NOISE ############################
def words_in_trial(recognized_words: List[RecogResult],
start_time: float, # Seconds
end_time: float, # Seconds
tolerance: float = 2.0) -> List[str]:
"""Pick out the words in the babble mixture that fall within time window."""
start_time -= tolerance
end_time += tolerance
# print(recognized_words[0].keys())
words = [r.word for r in recognized_words
if r.end_time >= start_time and r.start_time <= end_time]
# Remove all but word characters (not punctuation)
words = [re.sub(r'[^\w]', '', word.lower()) for word in words]
return words
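# Sketch: with the default tolerance of 2.0s, a trial window of [10s, 15s] is
# widened to [8s, 17s], so a word spanning 9.5s-10.1s is included, and its
# punctuation is stripped (e.g. 'Jacket,' becomes 'jacket').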
def prettyprint_words_and_alternatives(words_and_alternatives):
results = []
for w in words_and_alternatives:
if isinstance(w, str):
results.append(w)
elif isinstance(w, (list, set)):
results.append('/'.join(list(w)))
else:
raise ValueError(f'Unexpected type in {words_and_alternatives}')
return results
def score_word_list(true_words: List[Set[str]],
                    recognized_words: List[str], max_count=0) -> int:
  """How many of the key words show up in the transcript?
  Args:
    true_words: A list of sets, each set containing a key word and its
      acceptable alternates.
    recognized_words: A list of recognized words to score.
    max_count: Maximum score to return (0 means no limit).
  Returns:
    The number of key words (as judged by the true_words list) that
    were correctly recognized.
  """
  score = 0
  missing_words = []
  for words_and_alternates in true_words:
    # A key word counts if any of its alternates was recognized.
    found = any(word in recognized_words for word in words_and_alternates)
    if found:
      score += 1
    else:
      missing_words.append(words_and_alternates)
  if max_count:
    score = min(score, max_count)
  if missing_words:
    missing_words = prettyprint_words_and_alternatives(missing_words)
    missing_words_string = ', '.join(missing_words)
    recognized_words_string = ', '.join(recognized_words)
    print(f'Could not find {missing_words_string} '
          f'in this recognition result: {recognized_words_string}')
  return score
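# Worked example with hypothetical inputs:
#   true_words = [{'tear', 'tara'}, {'thin'}, {'sheet'}, {'yellow'}, {'pad'}]
#   recognized_words = ['a', 'tara', 'thin', 'sheet', 'of', 'yellow', 'pads']
#   score_word_list(true_words, recognized_words, max_count=5)  # -> 4
# 'pad' is counted as missing because 'pads' is not an exact match.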
def score_all_tests(snrs: List[float],
                    ground_truths: List[List[SpinSentence]],
                    reco_results: SpinFileTranscripts,
                    # Good lists from McArdle2006 and Mead2006.
                    good_lists: List[int] = [1, 2, 6, 8, 10, 11, 12],
                    debug=False) -> np.ndarray:
  """Score all lists at all SNRs for one recognizer. Iterate through all the
  lists, and then score each SNR in the list.
  Args:
    snrs: List of (standard) test SNRs
    ground_truths: A list (one item per QuickSIN list) of lists (one per
      SNR) of the ground-truth sentences. This routine uses the true_word_list
      as the ground truth (this also contains the alternates).
    reco_results: A list (one per QuickSIN list) of recognition results.
    good_lists: Which QuickSIN lists to use for scoring. This list
      starts at 1 (while Python lists start counting at zero)!
  Returns:
    A np array with the fraction correct for each of the listed SNRs."""
num_lists = len(good_lists)
num_keywords = 5
correct_counts = []
for snr_num, snr in enumerate(snrs):
correct_count = 0
for list_num in good_lists:
list_num -= 1 # Since Python starts lists at 0.
true_words = ground_truths[list_num][snr_num].true_word_list
recognized_words = words_in_trial(
reco_results[list_num],
ground_truths[list_num][snr_num].start_time,
ground_truths[list_num][snr_num].end_time)
correct_this_trial = score_word_list(true_words, recognized_words,
max_count=5)
correct_count += correct_this_trial
if debug and correct_this_trial < 5:
print(f'SNR {snr}:')
print(f' Expected words: {true_words}')
print(f' Recognized words: {recognized_words}')
print(f' Correct count is {correct_this_trial}')
correct_counts.append(correct_count)
correct_frac = np.asarray(correct_counts,
dtype=float) / (num_keywords*num_lists)
return correct_frac
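# With snrs=(25, 20, 15, 10, 5, 0) and the seven default good_lists, each
# returned fraction is correct_count / (5 keywords * 7 lists), so a perfect
# recognizer scores np.array([1., 1., 1., 1., 1., 1.]).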
# Models listed here:
# https://cloud.google.com/speech-to-text/docs/transcription-model
all_model_names = {# Google cloud models
'latest_long': 'Google\nLatest\nLong',
# 'latest_short': 'Google\nLatest\nShort',
'telephony': 'Google\nTelephony',
                   # 'medical_dictation': 'Google\nMedical\nDictation',
                   # 'medical_conversation': 'Google\nMedical\nConversation',
'chirp': 'Google\nUSM/Chirp',
# Whisper models from https://github.com/openai/whisper
'whisper.base.en': 'Whisper\nBase',
'whisper.small.en': 'Whisper\nSmall',
'whisper.medium.en': 'Whisper\nMedium',
'whisper.large': 'Whisper\nLarge',
}
def recognize_with_all_models(
project_id: str,
spin_test_names: List[str],
model_names: Dict[str, str] = all_model_names) -> Dict[str,
SpinFileTranscripts]:
"""Recognize all QuickSIN test files (from the spin_test_names argument)
Return a dictionary of scores vs. list of lists of transcripts, keyed by
the model name.
"""
model_results = {}
for model_name in model_names:
    print(f'\nRecognizing with model: {model_name}')
asr_engine = make_recognizer_engine(model_name)
asr_engine.CreateSpeechClient(project_id, model=model_name)
asr_engine.CreateRecognizer()
babble_transcripts = recognize_all_spin(spin_test_names, asr_engine)
# Babble_transcripts is a SpinFileTranscripts List[List[RecogResults]]
model_results[model_name] = babble_transcripts
return model_results
def score_all_models(
model_results: Dict[str, SpinFileTranscripts],
ground_truths: List[List[SpinSentence]],
test_snrs: List[float] = spin_snrs) -> Dict[str, np.ndarray]:
"""Score all QuickSIN test files (from the spin_test_names argument) against
the ground_truths. Return a dictionary of scores vs. SNRs, keyed by the
model name.
"""
model_scores = {}
  # print(type(model_results), model_results)  # Debugging aid.
for model_name in model_results:
if model_name not in all_model_names:
continue
print(f'Scoring model {model_name}')
babble_transcripts = model_results[model_name]
scores = score_all_tests(test_snrs,
ground_truths,
babble_transcripts,
debug=True)
print('Score_all_tests returned', scores)
model_scores[model_name] = scores
return model_scores
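# The returned dictionary maps each model name (e.g. 'latest_long') to its
# six-element fraction-correct array, aligned with spin_snrs for plotting.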
def save_recognition_results(
recognition_results: Dict[str, SpinFileTranscripts],
recognition_json_file: str):
"""Save the recognition results for all models into a JSON file so we don't
have to query the cloud again."""
class DataclassEncoder(json.JSONEncoder):
def default(self, o):
if dataclasses.is_dataclass(o):
return dataclasses.asdict(o)
elif isinstance(o, set):
return list(o)
return super().default(o)
saved_data = {
'recognition_results': recognition_results,
'time': str(datetime.datetime.now())
}
with fsspec.open(recognition_json_file, 'w') as fp:
json.dump(saved_data, fp, cls=DataclassEncoder)
def load_recognition_results(filename: str) -> Dict[str, SpinFileTranscripts]:
"""Load the precomputed QuickSIN results from a file."""
with fsspec.open(filename, 'r') as fp:
all_results = json.load(fp)
  if 'recognition_results' in all_results:
    results = all_results['recognition_results']
    #pylint: disable=inconsistent-quotes
    print(f'Reloading recognition results saved at {all_results["time"]}')
  else:
    results = all_results  # Old-format file without a timestamp.
for k in results:
# print(type(results[k]), results[k])
list_of_lists = []
for i in results[k]: