-
Notifications
You must be signed in to change notification settings - Fork 303
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #136 from makaveli10/add_tests
Add tests.
- Loading branch information
Showing
10 changed files
with
327 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,3 +8,5 @@ kaldialign | |
soundfile | ||
ffmpeg-python | ||
scipy | ||
jiwer | ||
evaluate |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import json | ||
import os | ||
import scipy | ||
import websocket | ||
import unittest | ||
from unittest.mock import patch, MagicMock | ||
from whisper_live.client import TranscriptionClient, resample | ||
|
||
|
||
class BaseTestCase(unittest.TestCase): | ||
@patch('whisper_live.client.websocket.WebSocketApp') | ||
@patch('whisper_live.client.pyaudio.PyAudio') | ||
def setUp(self, mock_pyaudio, mock_websocket): | ||
self.mock_pyaudio_instance = MagicMock() | ||
mock_pyaudio.return_value = self.mock_pyaudio_instance | ||
self.mock_stream = MagicMock() | ||
self.mock_pyaudio_instance.open.return_value = self.mock_stream | ||
|
||
self.mock_ws_app = mock_websocket.return_value | ||
self.mock_ws_app.send = MagicMock() | ||
|
||
self.client = TranscriptionClient(host='localhost', port=9090, lang="en").client | ||
|
||
self.mock_pyaudio = mock_pyaudio | ||
self.mock_websocket = mock_websocket | ||
|
||
def tearDown(self): | ||
self.client.close_websocket() | ||
self.mock_pyaudio.stop() | ||
self.mock_websocket.stop() | ||
del self.client | ||
|
||
|
||
class TestClientWebSocketCommunication(BaseTestCase): | ||
def test_websocket_communication(self): | ||
expected_url = 'ws://localhost:9090' | ||
self.mock_websocket.assert_called() | ||
self.assertEqual(self.mock_websocket.call_args[0][0], expected_url) | ||
|
||
|
||
class TestClientCallbacks(BaseTestCase): | ||
def test_on_open(self): | ||
expected_message = json.dumps({ | ||
"uid": self.client.uid, | ||
"language": self.client.language, | ||
"task": self.client.task, | ||
"model": self.client.model, | ||
}) | ||
self.client.on_open(self.mock_ws_app) | ||
self.mock_ws_app.send.assert_called_with(expected_message) | ||
|
||
def test_on_message(self): | ||
message = json.dumps( | ||
{ | ||
"uid": self.client.uid, | ||
"message": "SERVER_READY", | ||
"backend": "faster_whisper" | ||
} | ||
) | ||
self.client.on_message(self.mock_ws_app, message) | ||
|
||
message = json.dumps({ | ||
"uid": self.client.uid, | ||
"segments": [ | ||
{"start": 0, "end": 1, "text": "Test transcript"}, | ||
{"start": 1, "end": 2, "text": "Test transcript 2"}, | ||
{"start": 2, "end": 3, "text": "Test transcript 3"} | ||
] | ||
}) | ||
self.client.on_message(self.mock_ws_app, message) | ||
|
||
# Assert that the transcript was updated correctly | ||
self.assertEqual(len(self.client.transcript), 2) | ||
self.assertEqual(self.client.transcript[1]['text'], "Test transcript 2") | ||
|
||
def test_on_close(self): | ||
close_status_code = 1000 | ||
close_msg = "Normal closure" | ||
self.client.on_close(self.mock_ws_app, close_status_code, close_msg) | ||
|
||
self.assertFalse(self.client.recording) | ||
self.assertFalse(self.client.server_error) | ||
self.assertFalse(self.client.waiting) | ||
|
||
def test_on_error(self): | ||
error_message = "Test Error" | ||
self.client.on_error(self.mock_ws_app, error_message) | ||
|
||
self.assertTrue(self.client.server_error) | ||
self.assertEqual(self.client.error_message, error_message) | ||
|
||
|
||
class TestAudioResampling(unittest.TestCase): | ||
def test_resample_audio(self): | ||
original_audio = "assets/jfk.flac" | ||
expected_sr = 16000 | ||
resampled_audio = resample(original_audio, expected_sr) | ||
|
||
sr, _ = scipy.io.wavfile.read(resampled_audio) | ||
self.assertEqual(sr, expected_sr) | ||
|
||
os.remove(resampled_audio) | ||
|
||
|
||
class TestSendingAudioPacket(BaseTestCase): | ||
def test_send_packet(self): | ||
mock_audio_packet = b'\x00\x01\x02\x03' | ||
self.client.send_packet_to_server(mock_audio_packet) | ||
self.client.client_socket.send.assert_called_with(mock_audio_packet, websocket.ABNF.OPCODE_BINARY) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import subprocess | ||
import time | ||
import json | ||
import unittest | ||
from unittest import mock | ||
|
||
import numpy as np | ||
import evaluate | ||
from whisper_live.server import TranscriptionServer | ||
from whisper_live.client import TranscriptionClient | ||
from whisper.normalizers import EnglishTextNormalizer | ||
|
||
|
||
class TestTranscriptionServerInitialization(unittest.TestCase): | ||
def test_initialization(self): | ||
server = TranscriptionServer() | ||
self.assertEqual(server.max_clients, 4) | ||
self.assertEqual(server.max_connection_time, 600) | ||
self.assertDictEqual(server.clients, {}) | ||
self.assertDictEqual(server.websockets, {}) | ||
self.assertDictEqual(server.clients_start_time, {}) | ||
|
||
|
||
class TestGetWaitTime(unittest.TestCase): | ||
def setUp(self): | ||
self.server = TranscriptionServer() | ||
self.server.clients_start_time = { | ||
'client1': time.time() - 120, | ||
'client2': time.time() - 300 | ||
} | ||
self.server.max_connection_time = 600 | ||
|
||
def test_get_wait_time(self): | ||
expected_wait_time = (600 - (time.time() - self.server.clients_start_time['client2'])) / 60 | ||
print(self.server.get_wait_time(), expected_wait_time) | ||
self.assertAlmostEqual(self.server.get_wait_time(), expected_wait_time, places=2) | ||
|
||
|
||
class TestServerConnection(unittest.TestCase): | ||
def setUp(self): | ||
self.server = TranscriptionServer() | ||
|
||
@mock.patch('websockets.WebSocketCommonProtocol') | ||
def test_connection(self, mock_websocket): | ||
mock_websocket.recv.return_value = json.dumps({ | ||
'uid': 'test_client', | ||
'language': 'en', | ||
'task': 'transcribe', | ||
'model': 'tiny.en' | ||
}) | ||
self.server.recv_audio(mock_websocket, "faster_whisper") | ||
|
||
|
||
@mock.patch('websockets.WebSocketCommonProtocol') | ||
def test_recv_audio_exception_handling(self, mock_websocket): | ||
mock_websocket.recv.side_effect = [json.dumps({ | ||
'uid': 'test_client', | ||
'language': 'en', | ||
'task': 'transcribe', | ||
'model': 'tiny.en' | ||
}), np.array([1, 2, 3]).tobytes()] | ||
|
||
with self.assertLogs(level="ERROR"): | ||
self.server.recv_audio(mock_websocket, "faster_whisper") | ||
|
||
self.assertNotIn(mock_websocket, self.server.clients) | ||
|
||
|
||
class TestServerInferenceAccuracy(unittest.TestCase): | ||
@classmethod | ||
def setUpClass(cls): | ||
cls.server_process = subprocess.Popen(["python", "run_server.py"]) # Adjust the command as needed | ||
time.sleep(2) | ||
|
||
@classmethod | ||
def tearDownClass(cls): | ||
cls.server_process.terminate() | ||
cls.server_process.wait() | ||
|
||
@mock.patch('pyaudio.PyAudio') | ||
def setUp(self, mock_pyaudio): | ||
self.mock_pyaudio = mock_pyaudio.return_value | ||
self.mock_stream = mock.MagicMock() | ||
self.mock_pyaudio.open.return_value = self.mock_stream | ||
self.metric = evaluate.load("wer") | ||
self.normalizer = EnglishTextNormalizer() | ||
self.client = TranscriptionClient( | ||
"localhost", "9090", model="base.en", lang="en", | ||
) | ||
|
||
def test_inference(self): | ||
gt = "And so my fellow Americans, ask not, what your country can do for you. Ask what you can do for your country!" | ||
self.client("assets/jfk.flac") | ||
with open("output.srt", "r") as f: | ||
lines = f.readlines() | ||
prediction = " ".join([l.strip() for l in lines[2::4]]) | ||
prediction_normalized = self.normalizer(prediction) | ||
gt_normalized = self.normalizer(gt) | ||
|
||
# calculate WER | ||
wer = self.metric.compute( | ||
predictions=[prediction_normalized], | ||
references=[gt_normalized] | ||
) | ||
self.assertLess(wer, 0.05) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import unittest | ||
import numpy as np | ||
import torch | ||
import scipy.io as sio | ||
from whisper_live.tensorrt_utils import load_audio | ||
from whisper_live.vad import VoiceActivityDetection | ||
|
||
|
||
class TestVoiceActivityDetection(unittest.TestCase): | ||
def setUp(self): | ||
self.vad = VoiceActivityDetection() | ||
self.sample_rate = 16000 | ||
|
||
def generate_silence(self, duration_seconds): | ||
return np.zeros(int(self.sample_rate * duration_seconds), dtype=np.float32) | ||
|
||
def load_speech_segment(self, filepath): | ||
return load_audio(filepath) | ||
|
||
def test_vad_silence_detection(self): | ||
silence = self.generate_silence(3) | ||
speech_prob = self.vad(torch.from_numpy(silence.copy()), self.sample_rate).item() | ||
self.assertLess(speech_prob, 0.5, "VAD incorrectly identified silence as speech.") | ||
|
||
def test_vad_speech_detection(self): | ||
audio_tensor = torch.from_numpy(load_audio("assets/jfk.flac")) | ||
speech_prob = self.vad(audio_tensor, self.sample_rate).item() | ||
self.assertGreater(speech_prob, 0.5, "VAD failed to identify speech segment.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.