Skip to content

Commit

Permalink
Further refinement of TensorRT-LLM backend based on WhisperS2T
Browse files Browse the repository at this point in the history
  • Loading branch information
Aleks committed Mar 29, 2024
1 parent 7778357 commit 033b008
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 11 deletions.
1 change: 1 addition & 0 deletions notebooks/async_inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json

import aiohttp

headers = {"accept": "application/json", "Content-Type": "application/json"}
Expand Down
1 change: 1 addition & 0 deletions notebooks/audio_url_inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json

import requests

headers = {"accept": "application/json", "Content-Type": "application/json"}
Expand Down
3 changes: 3 additions & 0 deletions notebooks/live_inference.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Test the live endpoint."""

import asyncio

import websockets


async def test_websocket_endpoint():
uri = "ws://localhost:5001/api/v1/live?source_lang=en" # Replace with the actual WebSocket URL
async with websockets.connect(uri) as websocket:
Expand All @@ -16,5 +18,6 @@ async def test_websocket_endpoint():
except websockets.exceptions.ConnectionClosed:
print("WebSocket connection closed")


if __name__ == "__main__":
asyncio.get_event_loop().run_until_complete(test_websocket_endpoint())
2 changes: 1 addition & 1 deletion notebooks/local_audio_inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import requests

import requests

# filepath = "data/short_one_speaker.mp3"
# filepath = "data/24118946.mp3"
Expand Down
13 changes: 3 additions & 10 deletions notebooks/transcribe_endpoint_only.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,10 @@ def read_audio(
wav, sr = torchaudio.load(audio)
elif isinstance(audio, bytes):
with io.BytesIO(audio) as buffer:
wav, sr = sf.read(
buffer, format="RAW", channels=1, samplerate=16000, subtype="PCM_16"
)
wav, sr = sf.read(buffer, format="RAW", channels=1, samplerate=16000, subtype="PCM_16")
wav = torch.from_numpy(wav).unsqueeze(0)
else:
raise ValueError(
f"Invalid audio type. Must be either str or bytes, got: {type(audio)}."
)
raise ValueError(f"Invalid audio type. Must be either str or bytes, got: {type(audio)}.")

if wav.size(0) > 1:
wav = wav.mean(dim=0, keepdim=True)
Expand Down Expand Up @@ -88,7 +84,6 @@ class TranscribeRequest(BaseModel):


async def main():

audio, _ = read_audio("data/HL_Podcast_1.mp3")
ts = TensorShare.from_dict({"audio": audio}, backend=Backend.TORCH)

Expand All @@ -111,9 +106,7 @@ async def main():
headers={"Content-Type": "application/json"},
) as response:
if response.status != 200:
raise Exception(
f"Remote transcription failed with status {response.status}."
)
raise Exception(f"Remote transcription failed with status {response.status}.")
else:
r = await response.json()

Expand Down
1 change: 1 addition & 0 deletions notebooks/youtube_inference.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json

import requests

headers = {"accept": "application/json", "Content-Type": "application/json"}
Expand Down

0 comments on commit 033b008

Please sign in to comment.