Skip to content

Commit

Permalink
fix azure speech recognition
Browse files Browse the repository at this point in the history
  • Loading branch information
devxpy committed Feb 21, 2024
1 parent be300ec commit 967cdd8
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 107 deletions.
78 changes: 9 additions & 69 deletions daras_ai_v2/asr.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import json
import os.path
import subprocess
import os.path
import tempfile
from enum import Enum
from time import sleep

import langcodes
import requests
Expand All @@ -14,6 +12,7 @@
import gooey_ui as st
from daras_ai.image_input import upload_file_from_bytes, gs_url_to_uri
from daras_ai_v2 import settings
from daras_ai_v2.azure_asr import azure_asr
from daras_ai_v2.exceptions import (
raise_for_status,
UserError,
Expand Down Expand Up @@ -77,7 +76,6 @@
"so-SO", "sq-AL", "sr-RS", "sv-SE", "sw-KE", "sw-TZ", "ta-IN", "te-IN", "th-TH", "tr-TR", "uk-UA",
"ur-IN", "uz-UZ", "vi-VN", "wuu-CN", "yue-CN", "zh-CN", "zh-CN-shandong", "zh-CN-sichuan", "zh-HK",
"zh-TW", "zu-ZA"} # fmt: skip
MAX_POLLS = 100

# https://deepgram.com/product/languages for the "general" model:
# DEEPGRAM_SUPPORTED = {"nl","en","en-AU","en-US","en-GB","en-NZ","en-IN","fr","fr-CA","de","hi","hi-Latn","id","it","ja","ko","cmn-Hans-CN","cmn-Hant-TW","no","pl","pt","pt-PT","pt-BR","ru","es","es-419","sv","tr","uk"} # fmt: skip
Expand Down Expand Up @@ -165,7 +163,7 @@ def google_translate_language_selector(
label: the label to display
key: the key to save the selected language to in the session state
"""
languages = google_translate_languages()
languages = google_translate_target_languages()
options = list(languages.keys())
if allow_none:
options.insert(0, None)
Expand All @@ -178,8 +176,8 @@ def google_translate_language_selector(
)


@redis_cache_decorator
def google_translate_languages() -> dict[str, str]:
@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY)
def google_translate_target_languages() -> dict[str, str]:
"""
Get list of supported languages for Google Translate.
:return: Dictionary of language codes and display names.
Expand All @@ -199,8 +197,8 @@ def google_translate_languages() -> dict[str, str]:
}


@redis_cache_decorator
def google_translate_input_languages() -> dict[str, str]:
@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY)
def google_translate_source_languages() -> dict[str, str]:
"""
Get list of supported languages for Google Translate.
:return: Dictionary of language codes and display names.
Expand Down Expand Up @@ -283,11 +281,11 @@ def run_google_translate(
if source_language:
source_language = langcodes.Language.get(source_language).to_tag()
source_language = get_language_in_collection(
source_language, google_translate_input_languages().keys()
source_language, google_translate_source_languages().keys()
) # this will default to autodetect if language is not found as supported
target_language = langcodes.Language.get(target_language).to_tag()
target_language: str | None = get_language_in_collection(
target_language, google_translate_languages().keys()
target_language, google_translate_target_languages().keys()
)
if not target_language:
raise ValueError(f"Unsupported target language: {target_language!r}")
Expand Down Expand Up @@ -648,64 +646,6 @@ def _get_or_create_recognizer(
return recognizer


def azure_asr(audio_url: str, language: str):
    """Transcribe an audio file via the Azure Speech batch-transcription REST API.

    Transcription from an audio URL is only supported via the REST API or CLI,
    so this submits a batch transcription job and then polls it to completion.

    Args:
        audio_url: URL of the audio file, fetchable by Azure.
        language: BCP-47 locale (e.g. "en-US"); falls back to "en-US" when falsy.

    Returns:
        The combined display-form transcript, one line per transcription file.

    Raises:
        RuntimeError: if the job does not report "Succeeded" within MAX_POLLS polls.
    """
    # Start by initializing a request
    payload = {
        "contentUrls": [
            audio_url,
        ],
        "displayName": "Gooey Transcription",
        "model": None,  # let Azure pick the default base model for the locale
        "properties": {
            "wordLevelTimestampsEnabled": False,
        },
        "locale": language or "en-US",
    }
    r = requests.post(
        str(furl(settings.AZURE_SPEECH_ENDPOINT) / "speechtotext/v3.1/transcriptions"),
        headers={
            "Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY,
            "Content-Type": "application/json",
        },
        json=payload,
    )
    raise_for_status(r)
    uri = r.json()["self"]

    # poll for results
    for _ in range(MAX_POLLS):
        r = requests.get(
            uri,
            headers={
                "Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY,
            },
        )
        if not r.ok or not r.json()["status"] == "Succeeded":
            sleep(5)
            continue
        # job finished: fetch the list of result files
        r = requests.get(
            r.json()["links"]["files"],
            headers={
                "Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY,
            },
        )
        raise_for_status(r)
        transcriptions = []
        for value in r.json()["values"]:
            if value["kind"] != "Transcription":
                continue
            r = requests.get(
                value["links"]["contentUrl"],
                headers={"Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY},
            )
            raise_for_status(r)
            combined_phrases = r.json().get("combinedRecognizedPhrases") or [{}]
            transcriptions += [combined_phrases[0].get("display", "")]
        return "\n".join(transcriptions)

    # NOTE: was `assert False, ...` — asserts are stripped under `python -O`,
    # which would make this function silently return None on timeout.
    raise RuntimeError("Max polls exceeded, Azure speech did not yield a response")


# 16kHz, 16-bit, mono
FFMPEG_WAV_ARGS = ["-vn", "-acodec", "pcm_s16le", "-ac", "1", "-ar", "16000"]

Expand Down
96 changes: 96 additions & 0 deletions daras_ai_v2/azure_asr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import datetime
from time import sleep

import requests
from furl import furl

from daras_ai_v2 import settings
from daras_ai_v2.exceptions import (
raise_for_status,
)
from daras_ai_v2.redis_cache import redis_cache_decorator

# 20 mins timeout
MAX_POLLS = 200
POLL_INTERVAL = 6


def azure_asr(audio_url: str, language: str):
    """Transcribe an audio file via the Azure Speech batch-transcription REST API.

    Submits a batch transcription job, then polls it (up to MAX_POLLS times,
    POLL_INTERVAL seconds apart) until it succeeds.

    Args:
        audio_url: URL of the audio file, fetchable by Azure.
        language: BCP-47 locale (e.g. "en-US"); falls back to "en-US" when falsy.

    Returns:
        The combined display-form transcript, one line per transcription file.

    Raises:
        RuntimeError: if the job reports "Failed", or does not succeed within
            MAX_POLLS polls.
    """
    # Start by initializing a request
    # https://eastus.dev.cognitive.microsoft.com/docs/services/speech-to-text-api-v3-1/operations/Transcriptions_Create
    language = language or "en-US"
    r = requests.post(
        str(furl(settings.AZURE_SPEECH_ENDPOINT) / "speechtotext/v3.1/transcriptions"),
        headers=azure_auth_header(),
        json={
            "contentUrls": [audio_url],
            "displayName": f"Gooey Transcription {datetime.datetime.now().isoformat()} {language=} {audio_url=}",
            "model": azure_get_latest_model(language),
            "properties": {
                "wordLevelTimestampsEnabled": False,
                # "displayFormWordLevelTimestampsEnabled": True,
                # "diarizationEnabled": False,
                # "punctuationMode": "DictatedAndAutomatic",
                # "profanityFilterMode": "Masked",
            },
            "locale": language,
        },
    )
    raise_for_status(r)
    uri = r.json()["self"]

    # poll for results
    for _ in range(MAX_POLLS):
        r = requests.get(uri, headers=azure_auth_header())
        if not r.ok:
            sleep(POLL_INTERVAL)
            continue
        status = r.json()["status"]
        if status == "Failed":
            # "Failed" is terminal — fail fast instead of polling a dead job
            # for the full timeout and then reporting a misleading time-out.
            raise RuntimeError(f"Azure speech transcription failed: {r.text}")
        if status != "Succeeded":
            # "NotStarted" / "Running" — keep waiting
            sleep(POLL_INTERVAL)
            continue
        # job finished: fetch the list of result files
        r = requests.get(r.json()["links"]["files"], headers=azure_auth_header())
        raise_for_status(r)
        transcriptions = []
        for value in r.json()["values"]:
            if value["kind"] != "Transcription":
                continue
            r = requests.get(value["links"]["contentUrl"], headers=azure_auth_header())
            raise_for_status(r)
            combined_phrases = r.json().get("combinedRecognizedPhrases") or [{}]
            transcriptions += [combined_phrases[0].get("display", "")]
        return "\n".join(transcriptions)

    raise RuntimeError("Max polls exceeded, Azure speech did not yield a response")


@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY)
def azure_get_latest_model(language: str) -> dict | None:
    """Return the newest usable Azure STT base model for the given locale.

    Args:
        language: BCP-47 locale used to filter the base-model list.

    Returns:
        ``{"self": <model url>}`` suitable for the transcription request's
        ``model`` field, or None if no usable model was found.
    """
    # https://eastus.dev.cognitive.microsoft.com/docs/services/speech-to-text-api-v3-1/operations/Models_ListBaseModels
    r = requests.get(
        str(furl(settings.AZURE_SPEECH_ENDPOINT) / "speechtotext/v3.1/models/base"),
        headers=azure_auth_header(),
        params={"filter": f"locale eq '{language}'"},
    )
    raise_for_status(r)
    data = r.json()["values"]
    try:
        models = sorted(
            data,
            key=lambda m: _parse_created_datetime(m["createdDateTime"]),
            reverse=True,
        )
    # ignore date parsing errors
    except (KeyError, ValueError):
        # fall back to the API's own order — presumably oldest-first, so
        # reverse to get newest-first (TODO confirm against the API)
        models = data
        models.reverse()
    for model in models:
        if "whisper" in model["displayName"].lower():
            # whisper is pretty slow on azure, so we ignore it
            continue
        # models are sorted newest-first, so the first acceptable one is the latest
        return {"self": model["self"]}
    return None


def _parse_created_datetime(ts: str) -> datetime.datetime:
    """Parse an Azure `createdDateTime`, with or without fractional seconds.

    The previous single-format `strptime` raised ValueError on timestamps
    like "2024-02-21T06:57:45.613Z", silently disabling the sort entirely.
    """
    for fmt in ("%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%dT%H:%M:%S.%fZ"):
        try:
            return datetime.datetime.strptime(ts, fmt)
        except ValueError:
            continue
    raise ValueError(f"unrecognized createdDateTime: {ts!r}")


def azure_auth_header() -> dict[str, str]:
    """Build the request headers carrying the Azure Speech subscription key."""
    return {"Ocp-Apim-Subscription-Key": settings.AZURE_SPEECH_KEY}
4 changes: 2 additions & 2 deletions daras_ai_v2/azure_doc_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def azure_doc_extract_pages(
]


@redis_cache_decorator
@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY)
def azure_form_recognizer_models() -> dict[str, str]:
r = requests.get(
str(
Expand All @@ -40,7 +40,7 @@ def azure_form_recognizer_models() -> dict[str, str]:
return {value["modelId"]: value["description"] for value in r.json()["value"]}


@redis_cache_decorator
@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY)
def azure_form_recognizer(url: str, model_id: str, params: dict = None):
r = requests.post(
str(
Expand Down
5 changes: 3 additions & 2 deletions daras_ai_v2/glossary.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from daras_ai_v2.asr import google_translate_languages
from daras_ai_v2.asr import google_translate_target_languages

from daras_ai_v2.doc_search_settings_widgets import document_uploader

Expand Down Expand Up @@ -125,7 +125,8 @@ def get_langcodes_from_df(df: "pd.DataFrame") -> list[str]:
import langcodes

supported = {
langcodes.Language.get(code).language for code in google_translate_languages()
langcodes.Language.get(code).language
for code in google_translate_target_languages()
}
ret = []
for col in df.columns:
Expand Down
71 changes: 38 additions & 33 deletions daras_ai_v2/redis_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from daras_ai_v2 import settings


LOCK_TIMEOUT_SEC = 10 * 60


Expand All @@ -20,38 +19,44 @@ def get_redis_cache():
F = typing.TypeVar("F", bound=typing.Callable[..., typing.Any])


def redis_cache_decorator(fn: F) -> F:
@wraps(fn)
def wrapper(*args, **kwargs):
# hash the args and kwargs so they are not too long
args_hash = hashlib.sha256(f"{args}{kwargs}".encode()).hexdigest()
# create a readable cache key
cache_key = f"gooey/redis-cache-decorator/v1/{fn.__name__}/{args_hash}"
# get the redis cache
redis_cache = get_redis_cache()
# lock the cache key so that only one thread can run the function
lock = redis_cache.lock(
name=os.path.join(cache_key, "lock"), timeout=LOCK_TIMEOUT_SEC
)
try:
lock.acquire()
except redis.exceptions.LockError:
pass
try:
cache_val = redis_cache.get(cache_key)
# if the cache exists, return it
if cache_val:
return pickle.loads(cache_val)
# otherwise, run the function and cache the result
else:
result = fn(*args, **kwargs)
cache_val = pickle.dumps(result)
redis_cache.set(cache_key, cache_val)
return result
finally:
def redis_cache_decorator(fn: F = None, ex=None) -> F:
def decorator(fn: F) -> F:
@wraps(fn)
def wrapper(*args, **kwargs):
# hash the args and kwargs so they are not too long
args_hash = hashlib.sha256(f"{args}{kwargs}".encode()).hexdigest()
# create a readable cache key
cache_key = f"gooey/redis-cache-decorator/v1/{fn.__name__}/{args_hash}"
# get the redis cache
redis_cache = get_redis_cache()
# lock the cache key so that only one thread can run the function
lock = redis_cache.lock(
name=os.path.join(cache_key, "lock"), timeout=LOCK_TIMEOUT_SEC
)
try:
lock.release()
lock.acquire()
except redis.exceptions.LockError:
pass

return wrapper
try:
cache_val = redis_cache.get(cache_key)
# if the cache exists, return it
if cache_val:
return pickle.loads(cache_val)
# otherwise, run the function and cache the result
else:
result = fn(*args, **kwargs)
cache_val = pickle.dumps(result)
redis_cache.set(cache_key, cache_val, ex=ex)
return result
finally:
try:
lock.release()
except redis.exceptions.LockError:
pass

return wrapper

if fn is None:
return decorator
else:
return decorator(fn)
2 changes: 2 additions & 0 deletions daras_ai_v2/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,8 @@
REDIS_CACHE_URL = config("REDIS_CACHE_URL", "redis://localhost:6379")
TWITTER_BEARER_TOKEN = config("TWITTER_BEARER_TOKEN", None)

REDIS_MODELS_CACHE_EXPIRY = 60 * 60 * 24 * 7

GPU_CELERY_BROKER_URL = config("GPU_CELERY_BROKER_URL", "amqp://localhost:5674")
GPU_CELERY_RESULT_BACKEND = config(
"GPU_CELERY_RESULT_BACKEND", "redis://localhost:6374"
Expand Down
3 changes: 2 additions & 1 deletion daras_ai_v2/text_to_speech_settings_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from google.cloud import texttospeech

import gooey_ui as st
from daras_ai_v2 import settings
from daras_ai_v2.enum_selector_widget import enum_selector
from daras_ai_v2.exceptions import raise_for_status
from daras_ai_v2.redis_cache import redis_cache_decorator
Expand Down Expand Up @@ -380,7 +381,7 @@ def text_to_speech_settings(page):
)


@redis_cache_decorator
@redis_cache_decorator(ex=settings.REDIS_MODELS_CACHE_EXPIRY)
def google_tts_voices() -> dict[str, str]:
voices: list[texttospeech.Voice] = (
texttospeech.TextToSpeechClient().list_voices().voices
Expand Down

0 comments on commit 967cdd8

Please sign in to comment.