From 32ed089a76efdd477406456697b1abbb4eadbaf9 Mon Sep 17 00:00:00 2001 From: Sasa Trivic Date: Thu, 1 Feb 2024 14:54:29 +0100 Subject: [PATCH] Change faster whisper to work with new extension --- whisper_live/server.py | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/whisper_live/server.py b/whisper_live/server.py index f2746084..e3498a60 100644 --- a/whisper_live/server.py +++ b/whisper_live/server.py @@ -155,7 +155,7 @@ def recv_audio(self, options["model"] = faster_whisper_custom_model_path client = ServeClientFasterWhisper( websocket, - multilingual=options["multilingual"], + multilingual=False, language=options["language"], task=options["task"], client_uid=options["uid"], @@ -585,13 +585,11 @@ def __init__( "tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large-v2", "large-v3", ] - - self.multilingual = multilingual if not os.path.exists(model): - self.model_size_or_path = self.get_model_size(model) + self.model_size_or_path = self.check_valid_model(model) else: self.model_size_or_path = model - self.language = language if self.multilingual else "en" + self.language = "en" if self.model_size_or_path.endswith("en") else language self.task = task self.initial_prompt = initial_prompt self.vad_parameters = vad_parameters or {"threshold": 0.5} @@ -600,7 +598,7 @@ def __init__( if self.model_size_or_path == None: return - + self.transcriber = WhisperModel( self.model_size_or_path, device=device, @@ -620,9 +618,15 @@ def __init__( ) ) - def get_model_size(self, model_size): + def check_valid_model(self, model_size): """ - Returns the whisper model size based on multilingual. + Check if it's a valid whisper model size. + + Args: + model_size (str): The name of the model size to check. + + Returns: + str: The model size if valid, None otherwise. """ if model_size not in self.model_sizes: self.websocket.send( @@ -635,15 +639,6 @@ def get_model_size(self, model_size): ) ) return None - - if model_size.endswith("en") and self.multilingual: - logging.info(f"Setting multilingual to false with {model_size} which is english only model.") - self.multilingual = False - - if not model_size.endswith("en") and not self.multilingual: - logging.info(f"Setting multilingual to true with multilingual model {model_size}.") - self.multilingual = True - return model_size def speech_to_text(self):