Skip to content

Commit

Permalink
Change faster whisper to work with new extension
Browse files Browse the repository at this point in the history
  • Loading branch information
lightwastak3n committed Feb 1, 2024
1 parent 5b28dde commit 32ed089
Showing 1 changed file with 12 additions and 17 deletions.
29 changes: 12 additions & 17 deletions whisper_live/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def recv_audio(self,
options["model"] = faster_whisper_custom_model_path
client = ServeClientFasterWhisper(
websocket,
multilingual=options["multilingual"],
multilingual=False,
language=options["language"],
task=options["task"],
client_uid=options["uid"],
Expand Down Expand Up @@ -585,13 +585,11 @@ def __init__(
"tiny", "tiny.en", "base", "base.en", "small", "small.en",
"medium", "medium.en", "large-v2", "large-v3",
]

self.multilingual = multilingual
if not os.path.exists(model):
self.model_size_or_path = self.get_model_size(model)
self.model_size_or_path = self.check_valid_model(model)
else:
self.model_size_or_path = model
self.language = language if self.multilingual else "en"
self.language = "en" if self.model_size_or_path.endswith("en") else language
self.task = task
self.initial_prompt = initial_prompt
self.vad_parameters = vad_parameters or {"threshold": 0.5}
Expand All @@ -600,7 +598,7 @@ def __init__(

if self.model_size_or_path == None:
return

self.transcriber = WhisperModel(
self.model_size_or_path,
device=device,
Expand All @@ -620,9 +618,15 @@ def __init__(
)
)

def get_model_size(self, model_size):
def check_valid_model(self, model_size):
"""
Returns the whisper model size based on multilingual.
Check if it's a valid whisper model size.
Args:
model_size (str): The name of the model size to check.
Returns:
str: The model size if valid, None otherwise.
"""
if model_size not in self.model_sizes:
self.websocket.send(
Expand All @@ -635,15 +639,6 @@ def get_model_size(self, model_size):
)
)
return None

if model_size.endswith("en") and self.multilingual:
logging.info(f"Setting multilingual to false with {model_size} which is english only model.")
self.multilingual = False

if not model_size.endswith("en") and not self.multilingual:
logging.info(f"Setting multilingual to true with multilingual model {model_size}.")
self.multilingual = True

return model_size

def speech_to_text(self):
Expand Down

0 comments on commit 32ed089

Please sign in to comment.