Merge pull request #26 from jinmiaoluo/refactor-audio-api-code
Fix multiple issues
alesaccoia authored Jul 1, 2024
2 parents be70748 + ae74f7d commit 465403b
Showing 8 changed files with 152 additions and 67 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -173,4 +173,7 @@ poetry.toml
# LSP config files
pyrightconfig.json

# audio files created by development server
audio_files

# End of https://www.toptal.com/developers/gitignore/api/python
8 changes: 4 additions & 4 deletions client/index.html
@@ -75,7 +75,7 @@
display: none;
}
</style>
<script src='utils.js'></script>
<script defer src='utils.js'></script>
</head>
<body>
<h1>Transcribe a Web Audio Stream with Huggingface VAD + Whisper</h1>
@@ -132,11 +132,11 @@ <h1>Transcribe a Web Audio Stream with Huggingface VAD + Whisper</h1>
<option value="greek">Greek</option>
</select>
</div>
<button onclick="initWebSocket()">Connect</button>
<button id="connectButton">Connect</button>
</div>
<button id="startButton" onclick='startRecording()' disabled>Start Streaming
<button id="startButton" disabled>Start Streaming
</button>
<button id="stopButton" onclick='stopRecording()' disabled>Stop Streaming
<button id="stopButton" disabled>Stop Streaming
</button>
<div id="transcription"></div>
<br/>
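The inline onclick attributes above give way to plain element ids, and the script tag gains defer, so client/utils.js (changed below) can look the buttons up and bind handlers with addEventListener once the document has been parsed. A minimal sketch of that pattern, reusing the ids from this diff (illustrative only, not part of the commit):

  // With <script defer src='utils.js'>, this runs after the DOM is parsed,
  // so buttons declared later in <body> already exist when queried.
  const connectButton = document.querySelector('#connectButton');
  const startButton = document.querySelector('#startButton');
  connectButton.addEventListener('click', () => {
    // The real client calls connectWebsocketHandler() here; see utils.js below.
    startButton.disabled = false;
  });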
13 changes: 13 additions & 0 deletions client/realtime-audio-processor.js
@@ -0,0 +1,13 @@
class RealtimeAudioProcessor extends AudioWorkletProcessor {
constructor(options) {
super();
}

process(inputs, outputs, params) {
// ASR and VAD models typically require mono audio.
this.port.postMessage(inputs[0][0]);
return true;
}
}

registerProcessor('realtime-audio-processor', RealtimeAudioProcessor);
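For context, this AudioWorkletProcessor runs on the browser's audio rendering thread and posts each block of the first input channel (the Float32Array handed to process()) to the main thread over its MessagePort. A minimal sketch of how such a worklet is typically driven from the main thread (the commit's actual wiring is setupRecordingWorkletNode in client/utils.js below; this sketch is illustrative only):

  async function wireUpWorklet() {
    // Load the module so 'realtime-audio-processor' is registered, then create a node for it.
    const ctx = new AudioContext();
    await ctx.audioWorklet.addModule('realtime-audio-processor.js');
    const node = new AudioWorkletNode(ctx, 'realtime-audio-processor');
    node.port.onmessage = (event) => {
      // event.data is the mono Float32Array block posted by process() above.
      console.log('received', event.data.length, 'samples');
    };
    // Feed the microphone into the worklet.
    const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
    ctx.createMediaStreamSource(stream).connect(node);
  }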
172 changes: 113 additions & 59 deletions client/utils.js
@@ -9,20 +9,44 @@ let websocket;
let context;
let processor;
let globalStream;
let language;

const bufferSize = 4096;
let isRecording = false;

function initWebSocket() {
const websocketAddress = document.getElementById('websocketAddress');
const selectedLanguage = document.getElementById('languageSelect');
const websocketStatus = document.getElementById('webSocketStatus');
const startButton = document.getElementById('startButton');
const stopButton = document.getElementById('stopButton');
const websocketAddress = document.querySelector('#websocketAddress');
const selectedLanguage = document.querySelector('#languageSelect');
const websocketStatus = document.querySelector('#webSocketStatus');
const connectButton = document.querySelector("#connectButton");
const startButton = document.querySelector('#startButton');
const stopButton = document.querySelector('#stopButton');
const transcriptionDiv = document.querySelector('#transcription');
const languageDiv = document.querySelector('#detected_language');
const processingTimeDiv = document.querySelector('#processing_time');
const panel = document.querySelector('#silence_at_end_of_chunk_options_panel');
const selectedStrategy = document.querySelector('#bufferingStrategySelect');
const chunk_length_seconds = document.querySelector('#chunk_length_seconds');
const chunk_offset_seconds = document.querySelector('#chunk_offset_seconds');

websocketAddress.addEventListener("input", resetWebsocketHandler);

websocketAddress.addEventListener("keydown", (event) => {
if (event.key === 'Enter') {
event.preventDefault();
connectWebsocketHandler();
}
});

language = selectedLanguage.value !== 'multilingual' ? selectedLanguage.value : null;
connectButton.addEventListener("click", connectWebsocketHandler);

function resetWebsocketHandler() {
if (isRecording) {
stopRecordingHandler();
}
if (websocket.readyState === WebSocket.OPEN) {
websocket.close();
}
connectButton.disabled = false;
}

function connectWebsocketHandler() {
if (!websocketAddress.value) {
console.log("WebSocket address is required.");
return;
@@ -33,12 +57,14 @@ function initWebSocket() {
console.log("WebSocket connection established");
websocketStatus.textContent = 'Connected';
startButton.disabled = false;
connectButton.disabled = true;
};
websocket.onclose = event => {
console.log("WebSocket connection closed", event);
websocketStatus.textContent = 'Not Connected';
startButton.disabled = true;
stopButton.disabled = true;
connectButton.disabled = false;
};
websocket.onmessage = event => {
console.log("Message from server:", event.data);
@@ -48,10 +74,7 @@ }
}

function updateTranscription(transcript_data) {
const transcriptionDiv = document.getElementById('transcription');
const languageDiv = document.getElementById('detected_language');

if (transcript_data.words && transcript_data.words.length > 0) {
if (Array.isArray(transcript_data.words) && transcript_data.words.length > 0) {
// Append words with color based on their probability
transcript_data.words.forEach(wordData => {
const span = document.createElement('span');
@@ -74,47 +97,75 @@
transcriptionDiv.appendChild(document.createElement('br'));
} else {
// Fallback to plain text
transcriptionDiv.textContent += transcript_data.text + '\n';
const span = document.createElement('span');
span.textContent = transcript_data.text;
transcriptionDiv.appendChild(span);
transcriptionDiv.appendChild(document.createElement('br'));
}

// Update the language information
if (transcript_data.language && transcript_data.language_probability) {
languageDiv.textContent = transcript_data.language + ' (' + transcript_data.language_probability.toFixed(2) + ')';
} else {
languageDiv.textContent = 'Not Supported';
}

// Update the processing time, if available
const processingTimeDiv = document.getElementById('processing_time');
if (transcript_data.processing_time) {
processingTimeDiv.textContent = 'Processing time: ' + transcript_data.processing_time.toFixed(2) + ' seconds';
}
}

startButton.addEventListener("click", startRecordingHandler);

function startRecording() {
function startRecordingHandler() {
if (isRecording) return;
isRecording = true;

const AudioContext = window.AudioContext || window.webkitAudioContext;
context = new AudioContext();

navigator.mediaDevices.getUserMedia({audio: true}).then(stream => {
let onSuccess = async (stream) => {
// Push user config to server
let language = selectedLanguage.value !== 'multilingual' ? selectedLanguage.value : null;
sendAudioConfig(language);

globalStream = stream;
const input = context.createMediaStreamSource(stream);
processor = context.createScriptProcessor(bufferSize, 1, 1);
processor.onaudioprocess = e => processAudio(e);
const recordingNode = await setupRecordingWorkletNode();
recordingNode.port.onmessage = (event) => {
processAudio(event.data);
};
input.connect(recordingNode);
};
let onError = (error) => {
console.error(error);
};
navigator.mediaDevices.getUserMedia({
audio: {
echoCancellation: true,
autoGainControl: false,
noiseSuppression: true,
latency: 0
}
}).then(onSuccess, onError);

// chain up the audio graph
input.connect(processor).connect(context.destination);
// Disable start button and enable stop button
startButton.disabled = true;
stopButton.disabled = false;
}

sendAudioConfig();
}).catch(error => console.error('Error accessing microphone', error));
async function setupRecordingWorkletNode() {
await context.audioWorklet.addModule('realtime-audio-processor.js');

// Disable start button and enable stop button
document.getElementById('startButton').disabled = true;
document.getElementById('stopButton').disabled = false;
return new AudioWorkletNode(
context,
'realtime-audio-processor'
);
}

function stopRecording() {
stopButton.addEventListener("click", stopRecordingHandler);

function stopRecordingHandler() {
if (!isRecording) return;
isRecording = false;

@@ -128,42 +179,60 @@ function stopRecording() {
if (context) {
context.close().then(() => context = null);
}
document.getElementById('startButton').disabled = false;
document.getElementById('stopButton').disabled = true;
startButton.disabled = false;
stopButton.disabled = true;
}

function sendAudioConfig() {
let selectedStrategy = document.getElementById('bufferingStrategySelect').value;
function sendAudioConfig(language) {
let processingArgs = {};

if (selectedStrategy === 'silence_at_end_of_chunk') {
if (selectedStrategy.value === 'silence_at_end_of_chunk') {
processingArgs = {
chunk_length_seconds: parseFloat(document.getElementById('chunk_length_seconds').value),
chunk_offset_seconds: parseFloat(document.getElementById('chunk_offset_seconds').value)
chunk_length_seconds: parseFloat(chunk_length_seconds.value),
chunk_offset_seconds: parseFloat(chunk_offset_seconds.value)
};
}

const audioConfig = {
type: 'config',
data: {
sampleRate: context.sampleRate,
bufferSize: bufferSize,
channels: 1, // Assuming mono channel
channels: 1,
language: language,
processing_strategy: selectedStrategy,
processing_strategy: selectedStrategy.value,
processing_args: processingArgs
}
};

websocket.send(JSON.stringify(audioConfig));
}

function downsampleBuffer(buffer, inputSampleRate, outputSampleRate) {
if (inputSampleRate === outputSampleRate) {
return buffer;
function processAudio(sampleData) {
// ASR (Automatic Speech Recognition) and VAD (Voice Activity Detection)
// models typically require mono audio with a sampling rate of 16 kHz,
// represented as signed 16-bit integers (an Int16Array).
//
// Resampling on the client in JavaScript reduces the computational
// cost on the server.
const outputSampleRate = 16000;
const decreaseResultBuffer = decreaseSampleRate(sampleData, context.sampleRate, outputSampleRate);
const audioData = convertFloat32ToInt16(decreaseResultBuffer);

if (websocket && websocket.readyState === WebSocket.OPEN) {
websocket.send(audioData);
}
}

function decreaseSampleRate(buffer, inputSampleRate, outputSampleRate) {
if (inputSampleRate < outputSampleRate) {
console.error("Input sample rate is lower than the target sample rate; cannot downsample.");
return;
} else if (inputSampleRate === outputSampleRate) {
return buffer;
}

let sampleRateRatio = inputSampleRate / outputSampleRate;
let newLength = Math.round(buffer.length / sampleRateRatio);
let newLength = Math.ceil(buffer.length / sampleRateRatio);
let result = new Float32Array(newLength);
let offsetResult = 0;
let offsetBuffer = 0;
@@ -181,19 +250,6 @@ function downsampleBuffer(buffer, inputSampleRate, outputSampleRate) {
return result;
}

function processAudio(e) {
const inputSampleRate = context.sampleRate;
const outputSampleRate = 16000; // Target sample rate

const left = e.inputBuffer.getChannelData(0);
const downsampledBuffer = downsampleBuffer(left, inputSampleRate, outputSampleRate);
const audioData = convertFloat32ToInt16(downsampledBuffer);

if (websocket && websocket.readyState === WebSocket.OPEN) {
websocket.send(audioData);
}
}

function convertFloat32ToInt16(buffer) {
let l = buffer.length;
const buf = new Int16Array(l);
@@ -207,9 +263,7 @@ function convertFloat32ToInt16(buffer) {
// window.onload = initWebSocket;

function toggleBufferingStrategyPanel() {
let selectedStrategy = document.getElementById('bufferingStrategySelect').value;
let panel = document.getElementById('silence_at_end_of_chunk_options_panel');
if (selectedStrategy === 'silence_at_end_of_chunk') {
if (selectedStrategy.value === 'silence_at_end_of_chunk') {
panel.classList.remove('hidden');
} else {
panel.classList.add('hidden');
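Taken together, the new processAudio path takes each Float32Array block from the worklet, reduces it from the capture rate to 16 kHz, converts it to 16-bit PCM, and sends the bytes over the WebSocket. A quick worked check of that arithmetic (a standalone sketch, not part of the commit):

  // At a 48 kHz capture rate, the ratio is 48000 / 16000 = 3, so one second of
  // audio shrinks from 48000 samples to Math.ceil(48000 / 3) = 16000 samples.
  const ratio = 48000 / 16000;
  console.log(Math.ceil(48000 / ratio)); // 16000

  // convertFloat32ToInt16 maps [-1, 1] floats onto the signed 16-bit range:
  // a 0.5 sample becomes 0.5 * 0x7FFF = 16383.5, stored as 16383 in the Int16Array.
  const buf = new Int16Array(1);
  buf[0] = Math.min(1, 0.5) * 0x7FFF;
  console.log(buf[0]); // 16383

Doing this reduction in the browser keeps the server's VAD and ASR input at a fixed 16 kHz mono, 16-bit format, and it cuts bandwidth both by the sample-rate ratio and by half again from the float32-to-int16 conversion.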
6 changes: 5 additions & 1 deletion src/asr/whisper_asr.py
@@ -1,5 +1,6 @@
import os

import torch
from transformers import pipeline

from src.audio_utils import save_audio_to_file
@@ -9,9 +10,12 @@

class WhisperASR(ASRInterface):
def __init__(self, **kwargs):
device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = kwargs.get("model_name", "openai/whisper-large-v3")
self.asr_pipeline = pipeline(
"automatic-speech-recognition", model=model_name
"automatic-speech-recognition",
model=model_name,
device=device,
)

async def transcribe(self, client):
4 changes: 1 addition & 3 deletions src/audio_utils.py
@@ -8,10 +8,8 @@ async def save_audio_to_file(
"""
Saves the audio data to a file.
:param client_id: Unique identifier for the client.
:param audio_data: The audio data to save.
:param file_counters: Dictionary to keep track of file counts for each
client.
:param file_name: The name of the file.
:param audio_dir: Directory where audio files will be saved.
:param audio_format: Format of the audio file.
:return: Path to the saved audio file.
11 changes: 11 additions & 0 deletions src/main.py
@@ -1,6 +1,7 @@
import argparse
import asyncio
import json
import logging

from src.asr.asr_factory import ASRFactory
from src.vad.vad_factory import VADFactory
@@ -59,12 +60,22 @@ def parse_args():
default=None,
help="The path to the SSL key file if using secure websockets",
)
parser.add_argument(
"--log-level",
type=str,
default="error",
choices=["debug", "info", "warning", "error"],
help="Logging level: debug, info, warning, error. default: error",
)
return parser.parse_args()


def main():
args = parse_args()

logging.basicConfig()
logging.getLogger().setLevel(args.log_level.upper())

try:
vad_args = json.loads(args.vad_args)
asr_args = json.loads(args.asr_args)