Skip to content

Commit

Permalink
Merge pull request #136 from makaveli10/add_tests
Browse files Browse the repository at this point in the history
Add tests.
  • Loading branch information
zoq authored Feb 19, 2024
2 parents e40414a + 14974af commit 5e1174f
Show file tree
Hide file tree
Showing 10 changed files with 327 additions and 49 deletions.
115 changes: 76 additions & 39 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: CI
name: Test & Build CI/CD

on:
push:
Expand All @@ -7,46 +7,83 @@ on:
tags:
- v*
pull_request:
branches:
- main
branches: [ main ]
types: [opened, synchronize, reopened]

jobs:
build-and-push-package:
# Test job: installs the system and Python dependencies, then runs the
# unittest suite once per Python version listed in the matrix.
test:
runs-on: ubuntu-latest
timeout-minutes: 60
strategy:
matrix:
# '3.10' is quoted so YAML keeps it a string instead of reading it as the float 3.1.
python-version: [3.8, 3.9, '3.10', '3.11']
steps:
- uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}

# Cache pip's cache directory (excluding its log directory). The key combines
# the runner OS, the Python version and a hash of both requirements files;
# restore-keys falls back to the OS + Python version prefix.
- name: Cache Python dependencies
uses: actions/cache@v2
with:
path: |
~/.cache/pip
!~/.cache/pip/log
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ hashFiles('requirements/server.txt', 'requirements/client.txt') }}
restore-keys: |
${{ runner.os }}-pip-${{ matrix.python-version }}-
- name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y ffmpeg portaudio19-dev

# Server requirements are installed with PyTorch's CPU wheel index as an extra index.
- name: Install Python dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements/server.txt --extra-index-url https://download.pytorch.org/whl/cpu
pip install -r requirements/client.txt
# Discovers and runs every test under the tests/ directory.
- name: Run tests
run: |
echo "Running tests with Python ${{ matrix.python-version }}"
python -m unittest discover -s tests
build-and-push:
needs: test
runs-on: ubuntu-latest
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
steps:
- name: Check Out Repository
uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.8

- name: Set up FFmpeg
uses: FedericoCarboni/setup-ffmpeg@v2

- name: Install Additional requirements
run: |
sudo apt-get -y install portaudio19-dev wget
shell: bash

- name: Install Client Requirements
run: pip install -r requirements/client.txt

- name: Install Server Requirements
run: pip install -r requirements/server.txt

- name: Install Wheel for build
run: pip install wheel twine

- name: Build wheel
run: |
python setup.py sdist bdist_wheel
- name: Push package on Test PyPI
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
- uses: actions/checkout@v2

- name: Set up Python 3.8
uses: actions/setup-python@v2
with:
python-version: 3.8

- name: Cache Python dependencies
uses: actions/cache@v2
with:
path: |
~/.cache/pip
!~/.cache/pip/log
key: ubuntu-latest-pip-3.8-${{ hashFiles('requirements/server.txt', 'requirements/client.txt') }}
restore-keys: |
ubuntu-latest-pip-3.8-
- name: Install system dependencies
run: sudo apt-get update && sudo apt-get install -y ffmpeg portaudio19-dev

- name: Install Python dependencies
run: |
pip install -r requirements/server.txt
pip install -r requirements/client.txt
- name: Build package
run: python setup.py sdist bdist_wheel

- name: Publish package to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
user: __token__
password: ${{ secrets.PYPI_API_TOKEN }}
Binary file added assets/jfk.flac
Binary file not shown.
2 changes: 2 additions & 0 deletions requirements/server.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ kaldialign
soundfile
ffmpeg-python
scipy
jiwer
evaluate
Empty file added tests/__init__.py
Empty file.
109 changes: 109 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import json
import os
import scipy
import websocket
import unittest
from unittest.mock import patch, MagicMock
from whisper_live.client import TranscriptionClient, resample


class BaseTestCase(unittest.TestCase):
    """Shared fixture providing a client built on mocked audio and websocket layers.

    Attributes set for subclasses:
        mock_pyaudio: Mock replacing ``whisper_live.client.pyaudio.PyAudio``.
        mock_pyaudio_instance: Object returned when ``mock_pyaudio`` is called.
        mock_stream: Object returned by ``mock_pyaudio_instance.open``.
        mock_websocket: Mock replacing ``whisper_live.client.websocket.WebSocketApp``.
        mock_ws_app: Object returned when ``mock_websocket`` is called.
        client: The ``.client`` attribute of a ``TranscriptionClient`` created
            for ``localhost:9090`` with ``lang="en"``.
    """

    def setUp(self):
        """Start both patches and build the client under test."""
        # Patcher objects (rather than ``@patch`` decorators on ``setUp``) keep
        # the patches active for the whole test instead of undoing them as soon
        # as ``setUp`` returns. ``addCleanup`` guarantees they are stopped even
        # if ``setUp`` or the test fails part-way through.
        pyaudio_patcher = patch('whisper_live.client.pyaudio.PyAudio')
        self.mock_pyaudio = pyaudio_patcher.start()
        self.addCleanup(pyaudio_patcher.stop)

        websocket_patcher = patch('whisper_live.client.websocket.WebSocketApp')
        self.mock_websocket = websocket_patcher.start()
        self.addCleanup(websocket_patcher.stop)

        self.mock_pyaudio_instance = MagicMock()
        self.mock_pyaudio.return_value = self.mock_pyaudio_instance
        self.mock_stream = MagicMock()
        self.mock_pyaudio_instance.open.return_value = self.mock_stream

        self.mock_ws_app = self.mock_websocket.return_value
        self.mock_ws_app.send = MagicMock()

        self.client = TranscriptionClient(host='localhost', port=9090, lang="en").client

    def tearDown(self):
        """Close the client's websocket and drop the client reference."""
        # The patches themselves are stopped by the cleanups registered in
        # ``setUp``; calling ``.stop()`` on the mocks would not undo anything.
        self.client.close_websocket()
        del self.client


class TestClientWebSocketCommunication(BaseTestCase):
    """Checks how the client constructs its websocket app."""

    def test_websocket_communication(self):
        """The websocket app is created with the server URL as its first positional argument."""
        self.mock_websocket.assert_called()
        positional_args, _ = self.mock_websocket.call_args
        self.assertEqual(positional_args[0], 'ws://localhost:9090')


class TestClientCallbacks(BaseTestCase):
    """Drives the client's websocket callbacks directly against the mocked socket."""

    def test_on_open(self):
        """``on_open`` sends the client's uid, language, task and model as JSON."""
        handshake = json.dumps({
            "uid": self.client.uid,
            "language": self.client.language,
            "task": self.client.task,
            "model": self.client.model,
        })

        self.client.on_open(self.mock_ws_app)

        self.mock_ws_app.send.assert_called_with(handshake)

    def test_on_message(self):
        """A segments message following SERVER_READY updates the stored transcript."""
        server_ready = {
            "uid": self.client.uid,
            "message": "SERVER_READY",
            "backend": "faster_whisper"
        }
        self.client.on_message(self.mock_ws_app, json.dumps(server_ready))

        texts = ("Test transcript", "Test transcript 2", "Test transcript 3")
        segments_update = {
            "uid": self.client.uid,
            "segments": [
                {"start": begin, "end": begin + 1, "text": text}
                for begin, text in enumerate(texts)
            ]
        }
        self.client.on_message(self.mock_ws_app, json.dumps(segments_update))

        # Assert that the transcript was updated correctly
        self.assertEqual(len(self.client.transcript), 2)
        self.assertEqual(self.client.transcript[1]['text'], "Test transcript 2")

    def test_on_close(self):
        """``on_close`` leaves the recording, server_error and waiting flags false."""
        self.client.on_close(self.mock_ws_app, 1000, "Normal closure")

        for flag_name in ("recording", "server_error", "waiting"):
            self.assertFalse(getattr(self.client, flag_name))

    def test_on_error(self):
        """``on_error`` raises the error flag and stores the error it was given."""
        reported_error = "Test Error"

        self.client.on_error(self.mock_ws_app, reported_error)

        self.assertTrue(self.client.server_error)
        self.assertEqual(self.client.error_message, reported_error)


class TestAudioResampling(unittest.TestCase):
    """Checks that ``resample`` produces a WAV file at the requested sample rate."""

    def test_resample_audio(self):
        """Resample the bundled FLAC asset and verify the sample rate of the result."""
        # Import the submodule explicitly instead of relying on ``import scipy``
        # happening to expose ``scipy.io.wavfile`` as an attribute.
        from scipy.io import wavfile

        original_audio = "assets/jfk.flac"
        expected_sr = 16000
        resampled_audio = resample(original_audio, expected_sr)

        try:
            sr, _ = wavfile.read(resampled_audio)
            self.assertEqual(sr, expected_sr)
        finally:
            # Remove the generated file even when the read or the assertion
            # fails, so a failing run does not leave artifacts behind.
            if os.path.exists(resampled_audio):
                os.remove(resampled_audio)


class TestSendingAudioPacket(BaseTestCase):
    """Checks how audio bytes are handed to the client's socket."""

    def test_send_packet(self):
        """The packet is passed to the socket's ``send`` together with the binary opcode."""
        payload = b'\x00\x01\x02\x03'

        self.client.send_packet_to_server(payload)

        socket_send = self.client.client_socket.send
        socket_send.assert_called_with(payload, websocket.ABNF.OPCODE_BINARY)
105 changes: 105 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import subprocess
import time
import json
import unittest
from unittest import mock

import numpy as np
import evaluate
from whisper_live.server import TranscriptionServer
from whisper_live.client import TranscriptionClient
from whisper.normalizers import EnglishTextNormalizer


class TestTranscriptionServerInitialization(unittest.TestCase):
    """Checks the state of a freshly constructed ``TranscriptionServer``."""

    def test_initialization(self):
        """Limits have their expected values and all bookkeeping dicts start empty."""
        server = TranscriptionServer()

        self.assertEqual(server.max_clients, 4)
        self.assertEqual(server.max_connection_time, 600)
        for registry_name in ("clients", "websockets", "clients_start_time"):
            self.assertDictEqual(getattr(server, registry_name), {})


class TestGetWaitTime(unittest.TestCase):
    """Checks ``TranscriptionServer.get_wait_time`` with two connected clients."""

    def setUp(self):
        """Register two clients that connected 120 and 300 seconds ago."""
        self.server = TranscriptionServer()
        # Derive both start times from a single timestamp so their offsets
        # are exactly 120 s and 300 s relative to each other's reference.
        now = time.time()
        self.server.clients_start_time = {
            'client1': now - 120,
            'client2': now - 300
        }
        self.server.max_connection_time = 600

    def test_get_wait_time(self):
        """The wait time, in minutes, matches the time remaining for ``client2``."""
        elapsed = time.time() - self.server.clients_start_time['client2']
        expected_wait_time = (600 - elapsed) / 60
        # places=2 tolerates the small clock drift between computing the
        # expectation here and the server reading the clock itself.
        self.assertAlmostEqual(self.server.get_wait_time(), expected_wait_time, places=2)


class TestServerConnection(unittest.TestCase):
    """Feeds a mocked websocket into ``TranscriptionServer.recv_audio``."""

    @staticmethod
    def _client_options():
        """Return the JSON options message the mocked client sends first."""
        return json.dumps({
            'uid': 'test_client',
            'language': 'en',
            'task': 'transcribe',
            'model': 'tiny.en'
        })

    def setUp(self):
        """Create a fresh server for each test."""
        self.server = TranscriptionServer()

    def test_connection(self):
        """``recv_audio`` is invoked with a websocket whose ``recv`` yields the options message."""
        with mock.patch('websockets.WebSocketCommonProtocol') as mock_websocket:
            mock_websocket.recv.return_value = self._client_options()
            self.server.recv_audio(mock_websocket, "faster_whisper")

    def test_recv_audio_exception_handling(self):
        """An ERROR is logged and the websocket is not left registered as a client."""
        with mock.patch('websockets.WebSocketCommonProtocol') as mock_websocket:
            # First ``recv`` yields the options, the second a raw byte frame;
            # any further ``recv`` call exhausts the side-effect list.
            mock_websocket.recv.side_effect = [
                self._client_options(),
                np.array([1, 2, 3]).tobytes(),
            ]

            with self.assertLogs(level="ERROR"):
                self.server.recv_audio(mock_websocket, "faster_whisper")

            self.assertNotIn(mock_websocket, self.server.clients)


class TestServerInferenceAccuracy(unittest.TestCase):
    """End-to-end check: transcribe the bundled audio via a live server and score the WER."""

    @classmethod
    def setUpClass(cls):
        """Start ``run_server.py`` in a subprocess shared by all tests in this class."""
        import sys

        # Use the interpreter running the tests; a bare "python" is resolved
        # through PATH and may not be the environment with the dependencies.
        cls.server_process = subprocess.Popen([sys.executable, "run_server.py"])
        # Give the server a moment to start before a client connects.
        time.sleep(2)

    @classmethod
    def tearDownClass(cls):
        """Terminate the server subprocess and wait for it to exit."""
        cls.server_process.terminate()
        cls.server_process.wait()

    @mock.patch('pyaudio.PyAudio')
    def setUp(self, mock_pyaudio):
        """Build the WER metric, the text normalizer and a client with PyAudio mocked out."""
        self.mock_pyaudio = mock_pyaudio.return_value
        self.mock_stream = mock.MagicMock()
        self.mock_pyaudio.open.return_value = self.mock_stream
        self.metric = evaluate.load("wer")
        self.normalizer = EnglishTextNormalizer()
        self.client = TranscriptionClient(
            "localhost", "9090", model="base.en", lang="en",
        )

    def test_inference(self):
        """The normalized transcript of ``assets/jfk.flac`` has a WER below 0.05."""
        gt = (
            "And so my fellow Americans, ask not, what your country can do for you. "
            "Ask what you can do for your country!"
        )
        self.client("assets/jfk.flac")
        with open("output.srt", "r", encoding="utf-8") as f:
            lines = f.readlines()
        # Takes every fourth line starting at the third one.
        # NOTE(review): assumes each SRT entry is four lines with the text on
        # the third - confirm against the client's SRT writer.
        prediction = " ".join(line.strip() for line in lines[2::4])
        prediction_normalized = self.normalizer(prediction)
        gt_normalized = self.normalizer(gt)

        # calculate WER
        wer = self.metric.compute(
            predictions=[prediction_normalized],
            references=[gt_normalized]
        )
        self.assertLess(wer, 0.05)
28 changes: 28 additions & 0 deletions tests/test_vad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import unittest
import numpy as np
import torch
import scipy.io as sio
from whisper_live.tensorrt_utils import load_audio
from whisper_live.vad import VoiceActivityDetection


class TestVoiceActivityDetection(unittest.TestCase):
    """Checks the speech probability reported by ``VoiceActivityDetection``."""

    def setUp(self):
        """Create the detector and fix the sample rate used by the tests."""
        self.vad = VoiceActivityDetection()
        self.sample_rate = 16000

    def generate_silence(self, duration_seconds):
        """Return ``duration_seconds`` of float32 zeros at the test sample rate."""
        sample_count = int(self.sample_rate * duration_seconds)
        return np.zeros(sample_count, dtype=np.float32)

    def load_speech_segment(self, filepath):
        """Load the audio at ``filepath`` through ``load_audio``."""
        return load_audio(filepath)

    def test_vad_silence_detection(self):
        """Three seconds of silence scores below 0.5."""
        silent_samples = torch.from_numpy(self.generate_silence(3).copy())
        speech_prob = self.vad(silent_samples, self.sample_rate).item()
        self.assertLess(speech_prob, 0.5, "VAD incorrectly identified silence as speech.")

    def test_vad_speech_detection(self):
        """The bundled speech recording scores above 0.5."""
        speech_samples = self.load_speech_segment("assets/jfk.flac")
        audio_tensor = torch.from_numpy(speech_samples)
        speech_prob = self.vad(audio_tensor, self.sample_rate).item()
        self.assertGreater(speech_prob, 0.5, "VAD failed to identify speech segment.")
7 changes: 6 additions & 1 deletion whisper_live/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,15 @@ def on_message(self, ws, message):
print(element)

def on_error(self, ws, error):
    """
    Websocket error callback: report the error once and record it on the client.

    Args:
        ws: The websocket app that reported the error (unused).
        error: The error reported by the websocket.
    """
    # Only the formatted message is printed; also printing the bare error
    # would report the same failure twice.
    print(f"[ERROR] WebSocket Error: {error}")
    self.server_error = True
    self.error_message = error

def on_close(self, ws, close_status_code, close_msg):
    """
    Websocket close callback: report the closure and clear the client's state flags.

    Args:
        ws: The websocket app whose connection closed (unused).
        close_status_code: Status code reported for the closure.
        close_msg: Message reported for the closure.
    """
    closure_notice = f"[INFO]: Websocket connection closed: {close_status_code}: {close_msg}"
    print(closure_notice)
    # Chained assignment clears the three flags from left to right.
    self.recording = self.server_error = self.waiting = False

def on_open(self, ws):
"""
Expand Down
Loading

0 comments on commit 5e1174f

Please sign in to comment.