Adding batch_size option to transcription call
Aleks committed Apr 1, 2024
1 parent 5aa8f3a commit 0c6639e
Showing 8 changed files with 24 additions and 5 deletions.
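In short, the commit threads a per-request batch_size value from the request models and the HTTP endpoints down to the transcription service, and raises the TensorRT-LLM engine's max_batch_size from 24 to 64. As a minimal client sketch of how the new option could be used against the audio-file endpoint (the route path, port, and the "file" multipart field name are assumptions not shown in this diff; batch_size and num_speakers are form fields that are):

    import requests

    # Hypothetical route and port; the actual deployment details are not part of this commit.
    url = "http://localhost:5001/api/v1/audio"

    with open("sample.wav", "rb") as audio_file:
        response = requests.post(
            url,
            # "file" is an assumed multipart field name; batch_size and num_speakers
            # are form fields added/shown in audio_file_endpoint.py below.
            files={"file": audio_file},
            data={"batch_size": 8, "num_speakers": -1},
            timeout=600,
        )

    response.raise_for_status()
    print(response.json())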
@@ -79,7 +79,7 @@ def parse_arguments():
parser.add_argument("--quantize_dir", type=str, default="quantize/1-gpu")
parser.add_argument("--dtype", type=str, default="float16", choices=["float16"])
parser.add_argument("--log_level", type=str, default="info")
parser.add_argument("--max_batch_size", type=int, default=24)
parser.add_argument("--max_batch_size", type=int, default=64)
parser.add_argument("--max_input_len", type=int, default=4)
parser.add_argument("--max_output_len", type=int, default=448)
parser.add_argument("--max_beam_width", type=int, default=1)
2 changes: 1 addition & 1 deletion src/wordcab_transcribe/engines/tensorrt_llm/trt_model.py
@@ -93,7 +93,7 @@ def get_session(self, engine_dir, runtime_mapping, debug_mode=False):

# TODO: Make dynamic max_batch_size and max_beam_width
decoder_model_config = ModelConfig(
-max_batch_size=24,
+max_batch_size=64,
max_beam_width=1,
num_heads=self.decoder_config["num_heads"],
num_kv_heads=self.decoder_config["num_heads"],
4 changes: 4 additions & 0 deletions src/wordcab_transcribe/models.py
@@ -122,6 +122,7 @@ class Config:
},
],
"audio_duration": 2.678,
"batch_size": 1,
"offset_start": None,
"offset_end": None,
"num_speakers": -1,
@@ -176,6 +177,7 @@ class Config:
},
],
"audio_duration": 2.0,
"batch_size": 1,
"offset_start": None,
"offset_end": None,
"num_speakers": -1,
@@ -399,6 +401,7 @@ class BaseRequest(BaseModel):
offset_end: Union[float, None] = None
num_speakers: int = -1
diarization: bool = False
+batch_size: int = 1
source_lang: str = "en"
timestamps: Timestamps = Timestamps.seconds
vocab: Union[List[str], None] = None
@@ -462,6 +465,7 @@ class Config:

json_schema_extra = {
"example": {
"batch_size": 1,
"offset_start": None,
"offset_end": None,
"num_speakers": -1,
4 changes: 4 additions & 0 deletions src/wordcab_transcribe/router/v1/audio_file_endpoint.py
@@ -45,6 +45,7 @@
)
async def inference_with_audio( # noqa: C901
background_tasks: BackgroundTasks,
+batch_size: Union[int, None] = Form(None),  # noqa: B008
offset_start: Union[float, None] = Form(None), # noqa: B008
offset_end: Union[float, None] = Form(None), # noqa: B008
num_speakers: int = Form(-1), # noqa: B008
@@ -77,6 +78,7 @@ async def inference_with_audio( # noqa: C901
offset_end=offset_end,
num_speakers=num_speakers,
diarization=diarization,
+batch_size=batch_size,
source_lang=source_lang,
timestamps=timestamps,
vocab=vocab,
@@ -115,6 +117,7 @@ async def inference_with_audio( # noqa: C901
offset_end=data.offset_end,
num_speakers=data.num_speakers,
diarization=data.diarization,
+batch_size=data.batch_size,
multi_channel=data.multi_channel,
source_lang=data.source_lang,
timestamps_format=data.timestamps,
@@ -147,6 +150,7 @@ async def inference_with_audio( # noqa: C901
offset_end=data.offset_end,
num_speakers=data.num_speakers,
diarization=data.diarization,
+batch_size=batch_size,
multi_channel=data.multi_channel,
source_lang=data.source_lang,
timestamps=data.timestamps,
2 changes: 2 additions & 0 deletions src/wordcab_transcribe/router/v1/audio_url_endpoint.py
@@ -107,6 +107,7 @@ async def process_audio():
offset_end=data.offset_end,
num_speakers=data.num_speakers,
diarization=data.diarization,
+batch_size=data.batch_size,
multi_channel=data.multi_channel,
source_lang=data.source_lang,
timestamps_format=data.timestamps,
@@ -130,6 +131,7 @@ async def process_audio():
offset_end=data.offset_end,
num_speakers=data.num_speakers,
diarization=data.diarization,
+batch_size=data.batch_size,
multi_channel=data.multi_channel,
source_lang=data.source_lang,
timestamps=data.timestamps,
2 changes: 2 additions & 0 deletions src/wordcab_transcribe/router/v1/youtube_endpoint.py
@@ -58,6 +58,7 @@ async def inference_with_youtube(
offset_end=data.offset_end,
num_speakers=data.num_speakers,
diarization=data.diarization,
+batch_size=data.batch_size,
multi_channel=False,
source_lang=data.source_lang,
timestamps_format=data.timestamps,
@@ -90,6 +91,7 @@ async def inference_with_youtube(
offset_end=data.offset_end,
num_speakers=data.num_speakers,
diarization=data.diarization,
+batch_size=data.batch_size,
source_lang=data.source_lang,
timestamps=data.timestamps,
vocab=data.vocab,
9 changes: 8 additions & 1 deletion src/wordcab_transcribe/services/asr_service.py
@@ -98,6 +98,7 @@ class ASRTask(BaseModel):
url_type: Union[str, None]
diarization: "DiarizationTask"
duration: float
+batch_size: int
multi_channel: bool
offset_start: Union[float, None]
post_processing: "PostProcessingTask"
@@ -278,7 +279,7 @@ def __init__(
self.local_services: LocalServiceRegistry = LocalServiceRegistry()
self.remote_services: RemoteServiceRegistry = RemoteServiceRegistry()
self.dual_channel_transcribe_options: dict = {
"beam_size": 5,
"beam_size": 1,
"patience": 1,
"length_penalty": 1,
"suppress_blank": False,
@@ -366,6 +367,7 @@ async def inference_warmup(self) -> None:
logger.info(f"Warmup GPU {gpu_index}.")
await self.process_input(
filepath=str(sample_path),
+batch_size=1,
offset_start=None,
offset_end=None,
num_speakers=1,
@@ -386,6 +388,7 @@ async def inference_warmup(self) -> None:
async def process_input( # noqa: C901
self,
filepath: Union[str, List[str]],
+batch_size: Union[int, None],
offset_start: Union[float, None],
offset_end: Union[float, None],
num_speakers: int,
@@ -415,6 +418,8 @@ async def process_input( # noqa: C901
Args:
filepath (Union[str, List[str]]):
Path to the audio file or list of paths to the audio files to process.
+batch_size (Union[int, None]):
+The batch size to use for the transcription. For tensorrt-llm whisper engine only.
offset_start (Union[float, None]):
The start time of the audio file to process.
offset_end (Union[float, None]):
@@ -502,6 +507,7 @@ async def process_input( # noqa: C901
execution=diarization_execution, num_speakers=num_speakers
),
duration=duration,
+batch_size=batch_size,
multi_channel=multi_channel,
offset_start=offset_start,
post_processing=PostProcessingTask(),
@@ -1056,6 +1062,7 @@ async def process_input(
try:
result = self.transcription_service(
audio=data.audio,
+batch_size=data.batch_size,
source_lang=data.source_lang,
model_index=gpu_index,
suppress_blank=False,
4 changes: 2 additions & 2 deletions src/wordcab_transcribe/services/transcribe_service.py
@@ -126,7 +126,7 @@ def __call__(
],
source_lang: str,
model_index: int,
-batch_size: int = 24,
+batch_size: int = 1,
num_beams: int = 1,
suppress_blank: bool = False,
vocab: Union[List[str], None] = None,
@@ -150,7 +150,7 @@
model_index (int):
Index of the model to use.
batch_size (int):
-Batch size to use during generation.
+Batch size to use during generation. Only used for tensorrt_llm engine.
num_beams (int):
Number of beams to use during generation.
suppress_blank (bool):
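Because the per-request batch_size is only consumed by the tensorrt-llm Whisper engine, it is implicitly bounded by the max_batch_size the engine was built with (raised to 64 above). A minimal illustrative sketch of a clamp a caller might apply before invoking the transcription service; this helper is an assumption and is not part of the commit:

    from typing import Union

    # Must match the max_batch_size used when building the TensorRT-LLM engine (64 in this commit).
    MAX_ENGINE_BATCH_SIZE = 64

    def resolve_batch_size(requested: Union[int, None], default: int = 1) -> int:
        """Fall back to a default of 1 when unset and cap at the engine's build-time maximum."""
        if requested is None:
            return default
        return max(1, min(requested, MAX_ENGINE_BATCH_SIZE))

    assert resolve_batch_size(None) == 1    # unset: use the default
    assert resolve_batch_size(8) == 8       # typical request passes through
    assert resolve_batch_size(128) == 64    # capped to the engine limit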
