Skip to content

Commit

Permalink
added clip_timestamps to transcribe()
Browse files Browse the repository at this point in the history
-added parameter, `clip_timestamps`, to `transcribe()`

-fixed `AudioLoader` AttributeError with `_denoised_save_path`

-updated `AudioLoader` with ability to load only specific portions of the audio source

-updated  `AudioLoader` to provide more informative error messages when yt-dlp fails to load an URL

-removed redundant copying of encoder output, `audio_features`, in `DecodingTaskStable._get_audio_features()`

-changed `progress_callback` for `transcribe()` to only pass 2 positional arguments instead of 2 keyword arguments as intended (i.e. the parameters not longer need to be named `seek` and `total`)
  • Loading branch information
jianfch committed Nov 30, 2024
1 parent 9fe1bf5 commit 9fefdb8
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 13 deletions.
70 changes: 67 additions & 3 deletions stable_whisper/audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
import torch
import numpy as np
from typing import Union, Optional, Tuple
from typing import Union, Optional, Tuple, List

from .utils import (
is_ytdlp_available, load_source, load_audio, voice_freq_filter, get_samplerate, get_metadata
Expand Down Expand Up @@ -167,6 +167,8 @@ def __init__(
only_voice_freq: bool = False,
demucs: Optional[str] = None,
demucs_options: Optional[dict] = None,
load_sections: Optional[List[Tuple[float, Union[float, None]]]] = None,
negate_load: bool = False,
):
if stream and not isinstance(source, str):
raise NotImplementedError(f'``stream=True`` only supported for string ``source`` but got {type(source)}.')
Expand All @@ -175,6 +177,11 @@ def __init__(
from whisper.audio import SAMPLE_RATE
sr = SAMPLE_RATE
self._sr = sr
self.load_sections = (
self.negate_ts_sections(load_sections) if (negate_load and load_sections) else load_sections
)
self._curr_load_section_index = -1
self._curr_load_section_seeks = (0, 0)
if buffer_size is None:
buffer_size = (sr * 30)
self._buffer_size = self._valid_buffer_size(self.parse_chunk_size(buffer_size))
Expand Down Expand Up @@ -206,7 +213,12 @@ def __init__(
self._prev_unprep_samples = np.array([])
self._process = self._audio_loading_process()
if test_first_chunk and self.next_chunk(0) is None:
raise RuntimeError(f'FFmpeg failed to read "{source}".')
if self._extra_process is not None:
_, err = self._extra_process.communicate()
err = err.decode('utf-8', errors='ignore').strip('\n')
else:
err = f'FFmpeg failed to read "{source}".'
raise RuntimeError(err)

@property
def buffer_size(self):
Expand All @@ -228,6 +240,14 @@ def stream(self):
def prev_seek(self):
return self._prev_seek

@property
def curr_load_section_index(self):
return self._curr_load_section_index

@property
def curr_load_section_seeks(self):
return self._curr_load_section_seeks

@buffer_size.setter
def buffer_size(self, size: int):
self._buffer_size = self._valid_buffer_size(size)
Expand All @@ -239,6 +259,21 @@ def _valid_buffer_size(size: int):
raise ValueError('buffer size must be at least 0')
return size

@staticmethod
def negate_ts_sections(ts_sections: List[Tuple[Union[float], Union[float, None]]]) \
-> List[Tuple[Union[float], Union[float, None]]]:
new_sections = [(s0[1], s1[0]) for s0, s1 in zip(ts_sections[:-1], ts_sections[1:])]
new_sections.insert(0, (0.0, ts_sections[0][0]))
new_sections.append((ts_sections[-1][1], None))
new_sections = [s for s in new_sections if s[0] != s[1]]
return new_sections

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.terminate()

def parse_chunk_size(self, chunk_size: Union[int, str]) -> int:
if isinstance(chunk_size, int):
return chunk_size
Expand Down Expand Up @@ -376,6 +411,35 @@ def next_chunk(self, seek: int, size: Optional[int] = None) -> Union[torch.Tenso

return samples if len(samples) else None

def next_valid_chunk(self, seek: int, size: Optional[int] = None) -> Tuple[Union[torch.Tensor, None], int]:
if self.load_sections:
while (max_seek := self.curr_load_section_seeks[1]) is not None and seek + 1 >= max_seek:
if not self.skip_to_next_section():
return None, seek
if seek < self.curr_load_section_seeks[0]:
seek = self.curr_load_section_seeks[0]
chunk = self.next_chunk(seek, size=size)
if chunk is None:
return None, seek
size = chunk.size(-1)
max_seek = self.curr_load_section_seeks[1]
if max_seek is not None and seek + size > max_seek:
chunk = chunk[..., :max_seek - seek]
return chunk, seek
return self.next_chunk(seek, size=size), seek

def skip_to_next_section(self) -> bool:
if not self.load_sections or self.curr_load_section_index + 1 >= len(self.load_sections):
return False
self._curr_load_section_index += 1
start, end = self.load_sections[self._curr_load_section_index]
if start is not None:
start = round(start * self.sr)
if end is not None:
end = round(end * self.sr)
self._curr_load_section_seeks = (start, end)
return True

def _get_prep_func(self):

if self._denoiser:
Expand Down Expand Up @@ -480,7 +544,7 @@ def terminate(self):
self._extra_process.terminate()
if getattr(self, '_process', None) is not None and self._process.poll() is None:
self._process.terminate()
if getattr(self, '_denoised_save_path'):
if getattr(self, '_denoised_save_path', None):
self.save_denoised_audio()
if getattr(self, '_final_save_path', None):
self.save_final_audio()
Expand Down
6 changes: 2 additions & 4 deletions stable_whisper/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@ def __init__(self, *args, **kwargs):

def _get_audio_features(self, mel: torch.Tensor):
if self.audio_features is None:
audio_features = super()._get_audio_features(mel)
self.audio_features = audio_features.detach().clone()
return audio_features
return self.audio_features.clone()
self.audio_features = super()._get_audio_features(mel)
return self.audio_features

# modified version of whisper.DecodingTask._main_loop
def _main_loop(self, audio_features: torch.Tensor, tokens: torch.Tensor):
Expand Down
30 changes: 24 additions & 6 deletions stable_whisper/whisper_word_level/original_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def transcribe_stable(
ignore_compatibility: bool = False,
extra_models: Optional[List["Whisper"]] = None,
dynamic_heads: Optional[Union[bool, int, str]] = None,
clip_timestamps: Optional[Union[str, List[float]]] = None,
**decode_options) \
-> WhisperResult:
"""
Expand Down Expand Up @@ -197,6 +198,9 @@ def transcribe_stable(
word-timestamp extraction. Specify the number of heads or `True` for default of 6 heads.
To specify number of iterations for finding the optimal heads,
use string with "," to separate heads and iterations (e.g. "8,3" for 8 heads and 3 iterations).
clip_timestamps : str or list of float
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
The last end timestamp defaults to the end of the file.
decode_options
Keyword arguments to construct class:`whisper.decode.DecodingOptions` instances.
Expand Down Expand Up @@ -266,6 +270,15 @@ def transcribe_stable(
denoiser, denoiser_options, demucs=demucs, demucs_options=demucs_options
)

if isinstance(clip_timestamps, str):
clip_timestamps = [
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
]
if clip_timestamps:
clip_timestamps = [clip_timestamps[i:i+2] for i in range(0, len(clip_timestamps), 2)]
if len(clip_timestamps[-1]) == 1:
clip_timestamps[-1] = [clip_timestamps[-1][0], None]

if isinstance(audio, AudioLoader):
audio.validate_external_args(
sr=SAMPLE_RATE,
Expand All @@ -275,6 +288,7 @@ def transcribe_stable(
denoiser_options=denoiser_options,
only_voice_freq=only_voice_freq
)
audio.load_sections = clip_timestamps
else:
denoiser_options = update_options(denoiser_options, device=device)
audio = AudioLoader(
Expand All @@ -285,7 +299,8 @@ def transcribe_stable(
only_voice_freq=only_voice_freq,
only_ffmpeg=only_ffmpeg,
verbose=verbose,
new_chunk_divisor=512 if vad else None
new_chunk_divisor=512 if vad else None,
load_sections=clip_timestamps
)
tokenizer = None
language = None
Expand Down Expand Up @@ -421,18 +436,18 @@ def new_segment(

with tqdm(total=initial_duration, unit='sec', disable=verbose is not False, desc=task.title()) as tqdm_pbar:

def update_pbar():
def update_pbar(new_total=None):
nonlocal audio_features
audio_features = None
curr_total_duration = audio.get_duration(2)
curr_total_duration = audio.get_duration(2) if new_total is None else new_total
if curr_total_duration != tqdm_pbar.total:
tqdm_pbar.total = curr_total_duration
tqdm_pbar.refresh()
seek_duration = min(curr_total_duration, round(seek_sample / SAMPLE_RATE, 2))
if not tqdm_pbar.disable:
tqdm_pbar.update(seek_duration - tqdm_pbar.n)
if progress_callback is not None:
progress_callback(seek=seek_duration, total=curr_total_duration)
progress_callback(seek_duration, curr_total_duration)

def update_seek():
nonlocal seek_sample
Expand All @@ -444,9 +459,12 @@ def fast_forward():
update_pbar()

while True:
audio_segment = audio.next_chunk(seek_sample, N_SAMPLES)
audio_segment, new_seek = audio.next_valid_chunk(seek_sample, N_SAMPLES)
if audio_segment is None:
break
if new_seek != seek_sample:
seek_sample = new_seek
update_pbar()
time_offset = seek_sample / SAMPLE_RATE
segment_samples = audio_segment.shape[-1]
segment_duration = segment_samples / SAMPLE_RATE
Expand Down Expand Up @@ -658,7 +676,7 @@ def fast_forward():
fast_forward()

# final update
update_pbar()
update_pbar(seek_sample / SAMPLE_RATE)

if model.device != torch.device('cpu'):
torch.cuda.empty_cache()
Expand Down

0 comments on commit 9fefdb8

Please sign in to comment.