
Non_Whisper Refiner expects different inference function #430

Open
tmoroney opened this issue Jan 24, 2025 · 1 comment

Hi there, I am trying to get the Refiner to work with MLX Whisper, but it seems to expect a different inference function from the one used for transcribing: one that takes text_tokens as a parameter. I can't find an MLX Whisper function that takes text_tokens as input, so is it possible to avoid using a different inference function?

Currently I get this error:

stable_whisper/non_whisper/refinement.py", line 291, in get_prob
    token_probs: torch.Tensor = self.inference_func(audio_segment, text_tokens)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: inference() takes 1 positional argument but 2 were given

My code:

import mlx_whisper
import stable_whisper

def inference(audio, **kwargs) -> stable_whisper.result.WhisperResult:
    if kwargs["language"] == "auto":
        # Let MLX Whisper detect the language automatically.
        output = mlx_whisper.transcribe(
            audio,
            path_or_hf_repo=kwargs["model"],
            word_timestamps=True,
            verbose=True,
            task=kwargs["task"]
        )
    else:
        output = mlx_whisper.transcribe(
            audio,
            path_or_hf_repo=kwargs["model"],
            word_timestamps=True,
            language=kwargs["language"],
            verbose=True,
            task=kwargs["task"]
        )
    return stable_whisper.result.WhisperResult(output)

audio_file = "audio.wav"
kwargs = {
    "model": "mlx-community/whisper-small-mlx",
    "language": "en",
    "task": "transcribe"
}
result = stable_whisper.transcribe_any(
    inference, audio_file, inference_kwargs=kwargs, vad=False, regroup=True
)
modifier = stable_whisper.non_whisper.Refiner(inference_func=inference)
tokenizer = mlx_whisper.tokenizer.get_encoding()
result = modifier.refine(audio=audio_file, result=result, encode=tokenizer.encode)
print(result)
jianfch (Owner) commented Jan 25, 2025

Refinement adjusts the timestamps of transcribed words based on how the confidence scores change when the audio source changes. So it needs a function that takes in specific audio segments and words/tokens and outputs confidence scores for those words/tokens with respect to the audio segment.
This requires low-level access to the model, so mlx_whisper.transcribe() will not work: it takes in audio and outputs different words and timestamps for different audio inputs.
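
For context, the traceback above shows the call the Refiner makes: self.inference_func(audio_segment, text_tokens), which is expected to return a torch.Tensor of token probabilities. Below is a minimal, hypothetical sketch of that shape; the function name and body are illustrative, not part of stable-ts or MLX Whisper:

import torch

# Hypothetical stand-in for the Refiner's inference_func. It receives a
# specific audio segment plus the tokens of already-transcribed words, and
# must return per-token confidence scores for those tokens against that
# segment (e.g. from the decoder's softmax probabilities under teacher
# forcing). It must not run a fresh transcription the way
# mlx_whisper.transcribe() does.
def refiner_inference(audio_segment, text_tokens) -> torch.Tensor:
    raise NotImplementedError  # requires low-level (decoder) access to the model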
