From 967cdd8bb1229891f4c1e05d178d250b514749af Mon Sep 17 00:00:00 2001 From: Dev Aggarwal Date: Wed, 21 Feb 2024 22:01:09 +0300 Subject: [PATCH] fix azure speech recognition --- daras_ai_v2/asr.py | 78 ++------------- daras_ai_v2/azure_asr.py | 96 +++++++++++++++++++ daras_ai_v2/azure_doc_extract.py | 4 +- daras_ai_v2/glossary.py | 5 +- daras_ai_v2/redis_cache.py | 71 +++++++------- daras_ai_v2/settings.py | 2 + .../text_to_speech_settings_widgets.py | 3 +- 7 files changed, 152 insertions(+), 107 deletions(-) create mode 100644 daras_ai_v2/azure_asr.py diff --git a/daras_ai_v2/asr.py b/daras_ai_v2/asr.py index e37fb39b0..cc888d61e 100644 --- a/daras_ai_v2/asr.py +++ b/daras_ai_v2/asr.py @@ -1,9 +1,7 @@ -import json import os.path -import subprocess +import os.path import tempfile from enum import Enum -from time import sleep import langcodes import requests @@ -14,6 +12,7 @@ import gooey_ui as st from daras_ai.image_input import upload_file_from_bytes, gs_url_to_uri from daras_ai_v2 import settings +from daras_ai_v2.azure_asr import azure_asr from daras_ai_v2.exceptions import ( raise_for_status, UserError, @@ -77,7 +76,6 @@ "so-SO", "sq-AL", "sr-RS", "sv-SE", "sw-KE", "sw-TZ", "ta-IN", "te-IN", "th-TH", "tr-TR", "uk-UA", "ur-IN", "uz-UZ", "vi-VN", "wuu-CN", "yue-CN", "zh-CN", "zh-CN-shandong", "zh-CN-sichuan", "zh-HK", "zh-TW", "zu-ZA"} # fmt: skip -MAX_POLLS = 100 # https://deepgram.com/product/languages for the "general" model: # DEEPGRAM_SUPPORTED = {"nl","en","en-AU","en-US","en-GB","en-NZ","en-IN","fr","fr-CA","de","hi","hi-Latn","id","it","ja","ko","cmn-Hans-CN","cmn-Hant-TW","no","pl","pt","pt-PT","pt-BR","ru","es","es-419","sv","tr","uk"} # fmt: skip @@ -165,7 +163,7 @@ def google_translate_language_selector( label: the label to display key: the key to save the selected language to in the session state """ - languages = google_translate_languages() + languages = google_translate_target_languages() options = list(languages.keys()) if allow_none: options.insert(0, None) @@ -178,8 +176,8 @@ def google_translate_language_selector( ) -@redis_cache_decorator -def google_translate_languages() -> dict[str, str]: +@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY) +def google_translate_target_languages() -> dict[str, str]: """ Get list of supported languages for Google Translate. :return: Dictionary of language codes and display names. @@ -199,8 +197,8 @@ def google_translate_languages() -> dict[str, str]: } -@redis_cache_decorator -def google_translate_input_languages() -> dict[str, str]: +@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY) +def google_translate_source_languages() -> dict[str, str]: """ Get list of supported languages for Google Translate. :return: Dictionary of language codes and display names. @@ -283,11 +281,11 @@ def run_google_translate( if source_language: source_language = langcodes.Language.get(source_language).to_tag() source_language = get_language_in_collection( - source_language, google_translate_input_languages().keys() + source_language, google_translate_source_languages().keys() ) # this will default to autodetect if language is not found as supported target_language = langcodes.Language.get(target_language).to_tag() target_language: str | None = get_language_in_collection( - target_language, google_translate_languages().keys() + target_language, google_translate_target_languages().keys() ) if not target_language: raise ValueError(f"Unsupported target language: {target_language!r}") @@ -648,64 +646,6 @@ def _get_or_create_recognizer( return recognizer -def azure_asr(audio_url: str, language: str): - # transcription from audio url only supported via rest api or cli - # Start by initializing a request - payload = { - "contentUrls": [ - audio_url, - ], - "displayName": "Gooey Transcription", - "model": None, - "properties": { - "wordLevelTimestampsEnabled": False, - }, - "locale": language or "en-US", - } - r = requests.post( - str(furl(settings.AZURE_SPEECH_ENDPOINT) / "speechtotext/v3.1/transcriptions"), - headers={ - "Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY, - "Content-Type": "application/json", - }, - json=payload, - ) - raise_for_status(r) - uri = r.json()["self"] - - # poll for results - for _ in range(MAX_POLLS): - r = requests.get( - uri, - headers={ - "Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY, - }, - ) - if not r.ok or not r.json()["status"] == "Succeeded": - sleep(5) - continue - r = requests.get( - r.json()["links"]["files"], - headers={ - "Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY, - }, - ) - raise_for_status(r) - transcriptions = [] - for value in r.json()["values"]: - if value["kind"] != "Transcription": - continue - r = requests.get( - value["links"]["contentUrl"], - headers={"Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY}, - ) - raise_for_status(r) - combined_phrases = r.json().get("combinedRecognizedPhrases") or [{}] - transcriptions += [combined_phrases[0].get("display", "")] - return "\n".join(transcriptions) - assert False, "Max polls exceeded, Azure speech did not yield a response" - - # 16kHz, 16-bit, mono FFMPEG_WAV_ARGS = ["-vn", "-acodec", "pcm_s16le", "-ac", "1", "-ar", "16000"] diff --git a/daras_ai_v2/azure_asr.py b/daras_ai_v2/azure_asr.py new file mode 100644 index 000000000..aed873b03 --- /dev/null +++ b/daras_ai_v2/azure_asr.py @@ -0,0 +1,96 @@ +import datetime +from time import sleep + +import requests +from furl import furl + +from daras_ai_v2 import settings +from daras_ai_v2.exceptions import ( + raise_for_status, +) +from daras_ai_v2.redis_cache import redis_cache_decorator + +# 20 mins timeout +MAX_POLLS = 200 +POLL_INTERVAL = 6 + + +def azure_asr(audio_url: str, language: str): + # Start by initializing a request + # https://eastus.dev.cognitive.microsoft.com/docs/services/speech-to-text-api-v3-1/operations/Transcriptions_Create + language = language or "en-US" + r = requests.post( + str(furl(settings.AZURE_SPEECH_ENDPOINT) / "speechtotext/v3.1/transcriptions"), + headers=azure_auth_header(), + json={ + "contentUrls": [audio_url], + "displayName": f"Gooey Transcription {datetime.datetime.now().isoformat()} {language=} {audio_url=}", + "model": azure_get_latest_model(language), + "properties": { + "wordLevelTimestampsEnabled": False, + # "displayFormWordLevelTimestampsEnabled": True, + # "diarizationEnabled": False, + # "punctuationMode": "DictatedAndAutomatic", + # "profanityFilterMode": "Masked", + }, + "locale": language, + }, + ) + raise_for_status(r) + uri = r.json()["self"] + + # poll for results + for _ in range(MAX_POLLS): + r = requests.get(uri, headers=azure_auth_header()) + if not r.ok or not r.json()["status"] == "Succeeded": + sleep(POLL_INTERVAL) + continue + r = requests.get(r.json()["links"]["files"], headers=azure_auth_header()) + raise_for_status(r) + transcriptions = [] + for value in r.json()["values"]: + if value["kind"] != "Transcription": + continue + r = requests.get(value["links"]["contentUrl"], headers=azure_auth_header()) + raise_for_status(r) + combined_phrases = r.json().get("combinedRecognizedPhrases") or [{}] + transcriptions += [combined_phrases[0].get("display", "")] + return "\n".join(transcriptions) + + raise RuntimeError("Max polls exceeded, Azure speech did not yield a response") + + +@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY) +def azure_get_latest_model(language: str) -> dict | None: + # https://eastus.dev.cognitive.microsoft.com/docs/services/speech-to-text-api-v3-1/operations/Models_ListBaseModels + r = requests.get( + str(furl(settings.AZURE_SPEECH_ENDPOINT) / "speechtotext/v3.1/models/base"), + headers=azure_auth_header(), + params={"filter": f"locale eq '{language}'"}, + ) + raise_for_status(r) + data = r.json()["values"] + try: + models = sorted( + data, + key=lambda m: datetime.datetime.strptime( + m["createdDateTime"], "%Y-%m-%dT%H:%M:%SZ" + ), + reverse=True, + ) + # ignore date parsing errors + except ValueError: + models = data + models.reverse() + for model in models: + if "whisper" in model["displayName"].lower(): + # whisper is pretty slow on azure, so we ignore it + continue + # return the latest model + return {"self": model["self"]} + + +def azure_auth_header(): + return { + "Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY, + } diff --git a/daras_ai_v2/azure_doc_extract.py b/daras_ai_v2/azure_doc_extract.py index 173bb080d..878dc5733 100644 --- a/daras_ai_v2/azure_doc_extract.py +++ b/daras_ai_v2/azure_doc_extract.py @@ -26,7 +26,7 @@ def azure_doc_extract_pages( ] -@redis_cache_decorator +@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY) def azure_form_recognizer_models() -> dict[str, str]: r = requests.get( str( @@ -40,7 +40,7 @@ def azure_form_recognizer_models() -> dict[str, str]: return {value["modelId"]: value["description"] for value in r.json()["value"]} -@redis_cache_decorator +@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY) def azure_form_recognizer(url: str, model_id: str, params: dict = None): r = requests.post( str( diff --git a/daras_ai_v2/glossary.py b/daras_ai_v2/glossary.py index 2e56352da..5444135c1 100644 --- a/daras_ai_v2/glossary.py +++ b/daras_ai_v2/glossary.py @@ -1,4 +1,4 @@ -from daras_ai_v2.asr import google_translate_languages +from daras_ai_v2.asr import google_translate_target_languages from daras_ai_v2.doc_search_settings_widgets import document_uploader @@ -125,7 +125,8 @@ def get_langcodes_from_df(df: "pd.DataFrame") -> list[str]: import langcodes supported = { - langcodes.Language.get(code).language for code in google_translate_languages() + langcodes.Language.get(code).language + for code in google_translate_target_languages() } ret = [] for col in df.columns: diff --git a/daras_ai_v2/redis_cache.py b/daras_ai_v2/redis_cache.py index 4930e5427..f339bec80 100644 --- a/daras_ai_v2/redis_cache.py +++ b/daras_ai_v2/redis_cache.py @@ -8,7 +8,6 @@ from daras_ai_v2 import settings - LOCK_TIMEOUT_SEC = 10 * 60 @@ -20,38 +19,44 @@ def get_redis_cache(): F = typing.TypeVar("F", bound=typing.Callable[..., typing.Any]) -def redis_cache_decorator(fn: F) -> F: - @wraps(fn) - def wrapper(*args, **kwargs): - # hash the args and kwargs so they are not too long - args_hash = hashlib.sha256(f"{args}{kwargs}".encode()).hexdigest() - # create a readable cache key - cache_key = f"gooey/redis-cache-decorator/v1/{fn.__name__}/{args_hash}" - # get the redis cache - redis_cache = get_redis_cache() - # lock the cache key so that only one thread can run the function - lock = redis_cache.lock( - name=os.path.join(cache_key, "lock"), timeout=LOCK_TIMEOUT_SEC - ) - try: - lock.acquire() - except redis.exceptions.LockError: - pass - try: - cache_val = redis_cache.get(cache_key) - # if the cache exists, return it - if cache_val: - return pickle.loads(cache_val) - # otherwise, run the function and cache the result - else: - result = fn(*args, **kwargs) - cache_val = pickle.dumps(result) - redis_cache.set(cache_key, cache_val) - return result - finally: +def redis_cache_decorator(fn: F = None, ex=None) -> F: + def decorator(fn: F) -> F: + @wraps(fn) + def wrapper(*args, **kwargs): + # hash the args and kwargs so they are not too long + args_hash = hashlib.sha256(f"{args}{kwargs}".encode()).hexdigest() + # create a readable cache key + cache_key = f"gooey/redis-cache-decorator/v1/{fn.__name__}/{args_hash}" + # get the redis cache + redis_cache = get_redis_cache() + # lock the cache key so that only one thread can run the function + lock = redis_cache.lock( + name=os.path.join(cache_key, "lock"), timeout=LOCK_TIMEOUT_SEC + ) try: - lock.release() + lock.acquire() except redis.exceptions.LockError: pass - - return wrapper + try: + cache_val = redis_cache.get(cache_key) + # if the cache exists, return it + if cache_val: + return pickle.loads(cache_val) + # otherwise, run the function and cache the result + else: + result = fn(*args, **kwargs) + cache_val = pickle.dumps(result) + redis_cache.set(cache_key, cache_val, ex=ex) + return result + finally: + try: + lock.release() + except redis.exceptions.LockError: + pass + + return wrapper + + if fn is None: + return decorator + else: + return decorator(fn) diff --git a/daras_ai_v2/settings.py b/daras_ai_v2/settings.py index 512fb3834..d5d2741ae 100644 --- a/daras_ai_v2/settings.py +++ b/daras_ai_v2/settings.py @@ -293,6 +293,8 @@ REDIS_CACHE_URL = config("REDIS_CACHE_URL", "redis://localhost:6379") TWITTER_BEARER_TOKEN = config("TWITTER_BEARER_TOKEN", None) +REDIS_MODELS_CACHE_EXPIRY = 60 * 60 * 24 * 7 + GPU_CELERY_BROKER_URL = config("GPU_CELERY_BROKER_URL", "amqp://localhost:5674") GPU_CELERY_RESULT_BACKEND = config( "GPU_CELERY_RESULT_BACKEND", "redis://localhost:6374" diff --git a/daras_ai_v2/text_to_speech_settings_widgets.py b/daras_ai_v2/text_to_speech_settings_widgets.py index 499b502d2..f110126d2 100644 --- a/daras_ai_v2/text_to_speech_settings_widgets.py +++ b/daras_ai_v2/text_to_speech_settings_widgets.py @@ -4,6 +4,7 @@ from google.cloud import texttospeech import gooey_ui as st +from daras_ai_v2 import settings from daras_ai_v2.enum_selector_widget import enum_selector from daras_ai_v2.exceptions import raise_for_status from daras_ai_v2.redis_cache import redis_cache_decorator @@ -380,7 +381,7 @@ def text_to_speech_settings(page): ) -@redis_cache_decorator +@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY) def google_tts_voices() -> dict[str, str]: voices: list[texttospeech.Voice] = ( texttospeech.TextToSpeechClient().list_voices().voices