diff --git a/.gitignore b/.gitignore
index ad4a1f1..d0a976e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -173,4 +173,7 @@ poetry.toml
 # LSP config files
 pyrightconfig.json
 
+# audio files created by development server
+audio_files
+
 # End of https://www.toptal.com/developers/gitignore/api/python
diff --git a/client/index.html b/client/index.html
index f17860c..7807e86 100644
--- a/client/index.html
+++ b/client/index.html
@@ -75,7 +75,7 @@
         display: none;
     }
 [the +/- markup lines of this hunk were lost when the diff was rendered; only the visible page text survives]
 Transcribe a Web Audio Stream with Huggingface VAD + Whisper
@@ -132,11 +132,11 @@
 Transcribe a Web Audio Stream with Huggingface VAD + Whisper
 [the +/- markup lines of this hunk were likewise lost when the diff was rendered]
diff --git a/client/realtime-audio-processor.js b/client/realtime-audio-processor.js
new file mode 100644
index 0000000..b925913
--- /dev/null
+++ b/client/realtime-audio-processor.js
@@ -0,0 +1,13 @@
+class RealtimeAudioProcessor extends AudioWorkletProcessor {
+    constructor(options) {
+        super();
+    }
+
+    process(inputs, outputs, params) {
+        // ASR and VAD models typically require mono audio.
+        this.port.postMessage(inputs[0][0]);
+        return true;
+    }
+}
+
+registerProcessor('realtime-audio-processor', RealtimeAudioProcessor);
diff --git a/client/utils.js b/client/utils.js
index 0cbae17..812adc3 100644
--- a/client/utils.js
+++ b/client/utils.js
@@ -9,20 +9,44 @@ let websocket;
 let context;
 let processor;
 let globalStream;
-let language;
-
-const bufferSize = 4096;
 let isRecording = false;
 
-function initWebSocket() {
-    const websocketAddress = document.getElementById('websocketAddress');
-    const selectedLanguage = document.getElementById('languageSelect');
-    const websocketStatus = document.getElementById('webSocketStatus');
-    const startButton = document.getElementById('startButton');
-    const stopButton = document.getElementById('stopButton');
+const websocketAddress = document.querySelector('#websocketAddress');
+const selectedLanguage = document.querySelector('#languageSelect');
+const websocketStatus = document.querySelector('#webSocketStatus');
+const connectButton = document.querySelector("#connectButton");
+const startButton = document.querySelector('#startButton');
+const stopButton = document.querySelector('#stopButton');
+const transcriptionDiv = document.querySelector('#transcription');
+const languageDiv = document.querySelector('#detected_language');
+const processingTimeDiv = document.querySelector('#processing_time');
+const panel = document.querySelector('#silence_at_end_of_chunk_options_panel');
+const selectedStrategy = document.querySelector('#bufferingStrategySelect');
+const chunk_length_seconds = document.querySelector('#chunk_length_seconds');
+const chunk_offset_seconds = document.querySelector('#chunk_offset_seconds');
+
+websocketAddress.addEventListener("input", resetWebsocketHandler);
+
+websocketAddress.addEventListener("keydown", (event) => {
+    if (event.key === 'Enter') {
+        event.preventDefault();
+        connectWebsocketHandler();
+    }
+});
 
-    language = selectedLanguage.value !== 'multilingual' ? selectedLanguage.value : null;
+connectButton.addEventListener("click", connectWebsocketHandler);
 
+function resetWebsocketHandler() {
+    if (isRecording) {
+        stopRecordingHandler();
+    }
+    if (websocket.readyState === WebSocket.OPEN) {
+        websocket.close();
+    }
+    connectButton.disabled = false;
+}
+
+function connectWebsocketHandler() {
     if (!websocketAddress.value) {
         console.log("WebSocket address is required.");
         return;
@@ -33,12 +57,14 @@ function initWebSocket() {
         console.log("WebSocket connection established");
         websocketStatus.textContent = 'Connected';
         startButton.disabled = false;
+        connectButton.disabled = true;
     };
     websocket.onclose = event => {
         console.log("WebSocket connection closed", event);
         websocketStatus.textContent = 'Not Connected';
         startButton.disabled = true;
         stopButton.disabled = true;
+        connectButton.disabled = false;
     };
     websocket.onmessage = event => {
         console.log("Message from server:", event.data);
@@ -48,10 +74,7 @@ function initWebSocket() {
 }
 
 function updateTranscription(transcript_data) {
-    const transcriptionDiv = document.getElementById('transcription');
-    const languageDiv = document.getElementById('detected_language');
-
-    if (transcript_data.words && transcript_data.words.length > 0) {
+    if (Array.isArray(transcript_data.words) && transcript_data.words.length > 0) {
         // Append words with color based on their probability
         transcript_data.words.forEach(wordData => {
             const span = document.createElement('span');
@@ -74,47 +97,75 @@ function updateTranscription(transcript_data) {
         transcriptionDiv.appendChild(document.createElement('br'));
     } else {
         // Fallback to plain text
-        transcriptionDiv.textContent += transcript_data.text + '\n';
+        const span = document.createElement('span');
+        span.textContent = transcript_data.text;
+        transcriptionDiv.appendChild(span);
+        transcriptionDiv.appendChild(document.createElement('br'));
     }
 
     // Update the language information
     if (transcript_data.language && transcript_data.language_probability) {
         languageDiv.textContent = transcript_data.language + ' (' + transcript_data.language_probability.toFixed(2) + ')';
+    } else {
+        languageDiv.textContent = 'Not Supported';
     }
 
     // Update the processing time, if available
-    const processingTimeDiv = document.getElementById('processing_time');
     if (transcript_data.processing_time) {
         processingTimeDiv.textContent = 'Processing time: ' + transcript_data.processing_time.toFixed(2) + ' seconds';
     }
 }
 
+startButton.addEventListener("click", startRecordingHandler);
 
-function startRecording() {
+function startRecordingHandler() {
     if (isRecording) return;
     isRecording = true;
 
-    const AudioContext = window.AudioContext || window.webkitAudioContext;
     context = new AudioContext();
 
-    navigator.mediaDevices.getUserMedia({audio: true}).then(stream => {
+    let onSuccess = async (stream) => {
+        // Push user config to server
+        let language = selectedLanguage.value !== 'multilingual' ? selectedLanguage.value : null;
+        sendAudioConfig(language);
+
         globalStream = stream;
         const input = context.createMediaStreamSource(stream);
-        processor = context.createScriptProcessor(bufferSize, 1, 1);
-        processor.onaudioprocess = e => processAudio(e);
+        const recordingNode = await setupRecordingWorkletNode();
+        recordingNode.port.onmessage = (event) => {
+            processAudio(event.data);
+        };
+        input.connect(recordingNode);
+    };
+    let onError = (error) => {
+        console.error(error);
+    };
+    navigator.mediaDevices.getUserMedia({
+        audio: {
+            echoCancellation: true,
+            autoGainControl: false,
+            noiseSuppression: true,
+            latency: 0
+        }
+    }).then(onSuccess, onError);
 
-        // chain up the audio graph
-        input.connect(processor).connect(context.destination);
+    // Disable start button and enable stop button
+    startButton.disabled = true;
+    stopButton.disabled = false;
+}
 
-        sendAudioConfig();
-    }).catch(error => console.error('Error accessing microphone', error));
+async function setupRecordingWorkletNode() {
+    await context.audioWorklet.addModule('realtime-audio-processor.js');
 
-    // Disable start button and enable stop button
-    document.getElementById('startButton').disabled = true;
-    document.getElementById('stopButton').disabled = false;
+    return new AudioWorkletNode(
+        context,
+        'realtime-audio-processor'
+    );
 }
 
-function stopRecording() {
+stopButton.addEventListener("click", stopRecordingHandler);
+
+function stopRecordingHandler() {
     if (!isRecording) return;
     isRecording = false;
 
@@ -128,18 +179,17 @@ function stopRecording() {
     if (context) {
         context.close().then(() => context = null);
     }
-    document.getElementById('startButton').disabled = false;
-    document.getElementById('stopButton').disabled = true;
+    startButton.disabled = false;
+    stopButton.disabled = true;
 }
 
-function sendAudioConfig() {
-    let selectedStrategy = document.getElementById('bufferingStrategySelect').value;
+function sendAudioConfig(language) {
     let processingArgs = {};
 
-    if (selectedStrategy === 'silence_at_end_of_chunk') {
+    if (selectedStrategy.value === 'silence_at_end_of_chunk') {
         processingArgs = {
-            chunk_length_seconds: parseFloat(document.getElementById('chunk_length_seconds').value),
-            chunk_offset_seconds: parseFloat(document.getElementById('chunk_offset_seconds').value)
+            chunk_length_seconds: parseFloat(chunk_length_seconds.value),
+            chunk_offset_seconds: parseFloat(chunk_offset_seconds.value)
         };
     }
 
@@ -147,10 +197,9 @@ function sendAudioConfig() {
         type: 'config',
         data: {
             sampleRate: context.sampleRate,
-            bufferSize: bufferSize,
-            channels: 1, // Assuming mono channel
+            channels: 1,
             language: language,
-            processing_strategy: selectedStrategy,
+            processing_strategy: selectedStrategy.value,
             processing_args: processingArgs
         }
     };
@@ -158,12 +207,32 @@ function sendAudioConfig() {
     websocket.send(JSON.stringify(audioConfig));
 }
 
-function downsampleBuffer(buffer, inputSampleRate, outputSampleRate) {
-    if (inputSampleRate === outputSampleRate) {
-        return buffer;
+function processAudio(sampleData) {
+    // ASR (Automatic Speech Recognition) and VAD (Voice Activity Detection)
+    // models typically require mono audio with a sampling rate of 16 kHz,
+    // represented as a signed 16-bit integer (Int16Array) buffer.
+    //
+    // Resampling in the browser keeps this conversion off the server and
+    // reduces its computational cost.
+    const outputSampleRate = 16000;
+    const decreaseResultBuffer = decreaseSampleRate(sampleData, context.sampleRate, outputSampleRate);
+    const audioData = convertFloat32ToInt16(decreaseResultBuffer);
+
+    if (websocket && websocket.readyState === WebSocket.OPEN) {
+        websocket.send(audioData);
+    }
+}
+
+function decreaseSampleRate(buffer, inputSampleRate, outputSampleRate) {
+    if (inputSampleRate < outputSampleRate) {
+        console.error("Input sample rate is lower than the target sample rate.");
+        return buffer;
+    } else if (inputSampleRate === outputSampleRate) {
+        return buffer;
+    }
+
     let sampleRateRatio = inputSampleRate / outputSampleRate;
-    let newLength = Math.round(buffer.length / sampleRateRatio);
+    let newLength = Math.ceil(buffer.length / sampleRateRatio);
     let result = new Float32Array(newLength);
     let offsetResult = 0;
     let offsetBuffer = 0;
@@ -181,19 +250,6 @@ function downsampleBuffer(buffer, inputSampleRate, outputSampleRate) {
     return result;
 }
 
-function processAudio(e) {
-    const inputSampleRate = context.sampleRate;
-    const outputSampleRate = 16000; // Target sample rate
-
-    const left = e.inputBuffer.getChannelData(0);
-    const downsampledBuffer = downsampleBuffer(left, inputSampleRate, outputSampleRate);
-    const audioData = convertFloat32ToInt16(downsampledBuffer);
-
-    if (websocket && websocket.readyState === WebSocket.OPEN) {
-        websocket.send(audioData);
-    }
-}
-
 function convertFloat32ToInt16(buffer) {
     let l = buffer.length;
     const buf = new Int16Array(l);
@@ -207,9 +263,7 @@ function convertFloat32ToInt16(buffer) {
 // window.onload = initWebSocket;
 
 function toggleBufferingStrategyPanel() {
-    let selectedStrategy = document.getElementById('bufferingStrategySelect').value;
-    let panel = document.getElementById('silence_at_end_of_chunk_options_panel');
-    if (selectedStrategy === 'silence_at_end_of_chunk') {
+    if (selectedStrategy.value === 'silence_at_end_of_chunk') {
         panel.classList.remove('hidden');
     } else {
         panel.classList.add('hidden');
diff --git a/src/asr/whisper_asr.py b/src/asr/whisper_asr.py
index 2cb05bf..b472e43 100644
--- a/src/asr/whisper_asr.py
+++ b/src/asr/whisper_asr.py
@@ -1,5 +1,6 @@
 import os
 
+import torch
 from transformers import pipeline
 
 from src.audio_utils import save_audio_to_file
@@ -9,9 +10,12 @@ class WhisperASR(ASRInterface):
     def __init__(self, **kwargs):
+        device = "cuda" if torch.cuda.is_available() else "cpu"
         model_name = kwargs.get("model_name", "openai/whisper-large-v3")
         self.asr_pipeline = pipeline(
-            "automatic-speech-recognition", model=model_name
+            "automatic-speech-recognition",
+            model=model_name,
+            device=device,
        )
 
     async def transcribe(self, client):
diff --git a/src/audio_utils.py b/src/audio_utils.py
index 9ee4c7d..f9aa203 100644
--- a/src/audio_utils.py
+++ b/src/audio_utils.py
@@ -8,10 +8,8 @@ async def save_audio_to_file(
     """
     Saves the audio data to a file.
 
-    :param client_id: Unique identifier for the client.
     :param audio_data: The audio data to save.
-    :param file_counters: Dictionary to keep track of file counts for each
-        client.
+    :param file_name: The name of the file.
     :param audio_dir: Directory where audio files will be saved.
     :param audio_format: Format of the audio file.
     :return: Path to the saved audio file.
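Back on the client side, a quick numeric check of the conversion path that the new processAudio() in client/utils.js runs. This is a sketch only: it assumes a typical 48 kHz AudioContext, reuses the decreaseSampleRate() and convertFloat32ToInt16() helpers added above, and uses the 128-sample frame size of one AudioWorklet render quantum.

// Sketch: one 128-sample render quantum from the worklet, assuming 48 kHz input.
// 48000 / 16000 = 3, so 128 samples shrink to Math.ceil(128 / 3) = 43 samples.
const frame = new Float32Array(128).fill(0.5);         // dummy render quantum
const down = decreaseSampleRate(frame, 48000, 16000);  // Float32Array of length 43
const pcm16 = convertFloat32ToInt16(down);             // Int16Array; 0.5 maps to ~16383
// pcm16 (86 bytes of 16-bit PCM) is what websocket.send(audioData) transmits.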
diff --git a/src/main.py b/src/main.py
index 0151d1f..fad9c2d 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,6 +1,7 @@
 import argparse
 import asyncio
 import json
+import logging
 
 from src.asr.asr_factory import ASRFactory
 from src.vad.vad_factory import VADFactory
@@ -59,12 +60,22 @@ def parse_args():
         default=None,
         help="The path to the SSL key file if using secure websockets",
     )
+    parser.add_argument(
+        "--log-level",
+        type=str,
+        default="error",
+        choices=["debug", "info", "warning", "error"],
+        help="Logging level: debug, info, warning, error. default: error",
+    )
     return parser.parse_args()
 
 
 def main():
     args = parse_args()
 
+    logging.basicConfig()
+    logging.getLogger().setLevel(args.log_level.upper())
+
     try:
         vad_args = json.loads(args.vad_args)
         asr_args = json.loads(args.asr_args)
diff --git a/src/server.py b/src/server.py
index 00932df..809c7ad 100644
--- a/src/server.py
+++ b/src/server.py
@@ -1,4 +1,5 @@
 import json
+import logging
 import ssl
 import uuid
 
@@ -57,6 +58,7 @@ async def handle_audio(self, client, websocket):
                 config = json.loads(message)
                 if config.get("type") == "config":
                     client.update_config(config["data"])
+                    logging.debug(f"Updated config: {client.config}")
                     continue
                 else:
                     print(f"Unexpected message type from {client.client_id}")
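For reference, the config message that handle_audio() now logs via logging.debug is the JSON built by sendAudioConfig() in client/utils.js. A sketch of its shape with the silence_at_end_of_chunk strategy; the field names come from the diff above, while the concrete values are only illustrative.

// Illustrative payload only; field names from sendAudioConfig(), values made up.
const exampleConfig = {
    type: 'config',
    data: {
        sampleRate: 48000,        // context.sampleRate of the browser AudioContext
        channels: 1,
        language: 'en',           // null when 'multilingual' is selected
        processing_strategy: 'silence_at_end_of_chunk',
        processing_args: {
            chunk_length_seconds: 3,
            chunk_offset_seconds: 0.1
        }
    }
};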