Refactor Text2Speech: Keeping speech in memory #11
base: master
Changes from 13 commits
New file: langchain.tools.audio_utils

@@ -0,0 +1,18 @@
+import tempfile
+from pathlib import Path
+
+
+def save_audio(audio: bytes) -> str:
+    """Save audio to a temporary file and return the path."""
+    with tempfile.NamedTemporaryFile(mode="bx", suffix=".wav", delete=False) as f:
+        f.write(audio)
+    return f.name
+
+
+def load_audio(audio_file_path: str) -> bytes:
+    """Load audio from a file into bytes."""
+    if Path(audio_file_path).exists():
+        with open(audio_file_path, mode="rb") as f:
+            audio = f.read()
+        return audio
+    raise FileNotFoundError(f"File {audio_file_path} not found.")
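For context, a minimal round-trip usage of these two helpers; the placeholder bytes and variable names below are illustrative only:

from langchain.tools.audio_utils import load_audio, save_audio

# Placeholder: in practice `audio` would be the bytes returned by a TTS backend.
audio = b"RIFF....WAVEfmt "  # illustrative stand-in for real WAV bytes

path = save_audio(audio)      # written to a temporary .wav file, path returned
restored = load_audio(path)   # read the same bytes back into memory
assert restored == audio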
@@ -1,14 +1,28 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
import tempfile | ||
from typing import Any, Dict, Optional | ||
|
||
from IPython import display | ||
|
||
from langchain.callbacks.manager import CallbackManagerForToolRun | ||
from langchain.pydantic_v1 import root_validator | ||
from langchain.tools.audio_utils import load_audio, save_audio | ||
from langchain.tools.base import BaseTool | ||
from langchain.utils import get_from_dict_or_env | ||
|
||
|
||
def _import_azure_speech() -> Any: | ||
try: | ||
import azure.cognitiveservices.speech as speechsdk | ||
except ImportError: | ||
raise ImportError( | ||
"azure.cognitiveservices.speech is not installed. " | ||
"Run `pip install azure-cognitiveservices-speech` to install." | ||
) | ||
return speechsdk | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
|
@@ -33,6 +47,7 @@ class AzureCogsText2SpeechTool(BaseTool): | |
@root_validator(pre=True) | ||
def validate_environment(cls, values: Dict) -> Dict: | ||
"""Validate that api key and endpoint exists in environment.""" | ||
speechsdk = _import_azure_speech() | ||
azure_cogs_key = get_from_dict_or_env( | ||
values, "azure_cogs_key", "AZURE_COGS_KEY" | ||
) | ||
|
@@ -41,40 +56,21 @@ def validate_environment(cls, values: Dict) -> Dict: | |
values, "azure_cogs_region", "AZURE_COGS_REGION" | ||
) | ||
|
||
try: | ||
import azure.cognitiveservices.speech as speechsdk | ||
|
||
values["speech_config"] = speechsdk.SpeechConfig( | ||
subscription=azure_cogs_key, region=azure_cogs_region | ||
) | ||
except ImportError: | ||
raise ImportError( | ||
"azure-cognitiveservices-speech is not installed. " | ||
"Run `pip install azure-cognitiveservices-speech` to install." | ||
) | ||
|
||
values["speech_config"] = speechsdk.SpeechConfig( | ||
subscription=azure_cogs_key, region=azure_cogs_region | ||
) | ||
return values | ||
|
||
def _text2speech(self, text: str, speech_language: str) -> str: | ||
try: | ||
import azure.cognitiveservices.speech as speechsdk | ||
except ImportError: | ||
pass | ||
|
||
def _text2speech(self, text: str, speech_language: str) -> bytes: | ||
speechsdk = _import_azure_speech() | ||
self.speech_config.speech_synthesis_language = speech_language | ||
speech_synthesizer = speechsdk.SpeechSynthesizer( | ||
speech_config=self.speech_config, audio_config=None | ||
) | ||
result = speech_synthesizer.speak_text(text) | ||
|
||
if result.reason == speechsdk.ResultReason.SynthesizingAudioCompleted: | ||
stream = speechsdk.AudioDataStream(result) | ||
with tempfile.NamedTemporaryFile( | ||
mode="wb", suffix=".wav", delete=False | ||
) as f: | ||
stream.save_to_wav_file(f.name) | ||
|
||
return f.name | ||
return result.audio_data | ||
|
||
elif result.reason == speechsdk.ResultReason.Canceled: | ||
cancellation_details = result.cancellation_details | ||
|
@@ -84,10 +80,10 @@ def _text2speech(self, text: str, speech_language: str) -> str: | |
f"Speech synthesis error: {cancellation_details.error_details}" | ||
) | ||
|
||
return "Speech synthesis canceled." | ||
raise RuntimeError("Speech synthesis canceled.") | ||
|
||
else: | ||
return f"Speech synthesis failed: {result.reason}" | ||
raise RuntimeError(f"Speech synthesis failed: {result.reason}") | ||
Review comment (on the line above): Aren't you breaking agents by those changes?

     def _run(
         self,

@@ -96,7 +92,24 @@ def _run(
     ) -> str:
         """Use the tool."""
         try:
-            speech_file = self._text2speech(query, self.speech_language)
-            return speech_file
+            speech = self._text2speech(query, self.speech_language)
+            self.play(speech)
+            return "Speech has been generated"
         except Exception as e:
             raise RuntimeError(f"Error while running AzureCogsText2SpeechTool: {e}")

+    def play(self, speech: bytes) -> None:
+        """Play the speech."""
+        audio = display.Audio(speech)
+        display.display(audio)
+
+    def generate_and_save(self, query: str) -> str:
+        """Save the text as speech to a temporary file."""
+        speech = self._text2speech(query, self.speech_language)
+        path = save_audio(speech)
+        return path
+
+    def load_and_play(self, path: str) -> None:
+        """Load the text as speech from a temporary file and play it."""
+        speech = load_audio(path)
+        self.play(speech)
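A rough usage sketch of the refactored Azure tool, assuming AZURE_COGS_KEY and AZURE_COGS_REGION are set in the environment; the import path is assumed, not shown in this diff:

from langchain.tools import AzureCogsText2SpeechTool  # import path assumed

tts = AzureCogsText2SpeechTool()  # credentials read from AZURE_COGS_KEY / AZURE_COGS_REGION

# In-memory path: synthesize and play inline via IPython display, no temp file involved.
tts.run("Hello from the refactored tool")

# Optional persistence through the new audio_utils helpers.
path = tts.generate_and_save("Hello again")  # temporary .wav path
tts.load_and_play(path)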
ElevenLabsText2SpeechTool

@@ -1,9 +1,9 @@
-import tempfile
 from enum import Enum
 from typing import Any, Dict, Optional, Union

 from langchain.callbacks.manager import CallbackManagerForToolRun
 from langchain.pydantic_v1 import root_validator
+from langchain.tools.audio_utils import load_audio, save_audio
 from langchain.tools.base import BaseTool
 from langchain.utils import get_from_dict_or_env

@@ -56,20 +56,14 @@ def _run(
         elevenlabs = _import_elevenlabs()
         try:
             speech = elevenlabs.generate(text=query, model=self.model)
-            with tempfile.NamedTemporaryFile(
-                mode="bx", suffix=".wav", delete=False
-            ) as f:
-                f.write(speech)
-            return f.name
+            self.play(speech)
+            return "Speech has been generated"
Review comment (on the lines above): I think the best usage would be to have a single tool which can have different implementations provided, similar to PythonREPLTool.
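A possible shape for that suggestion, sketched only for discussion: a single tool that takes an injected synthesis callable, loosely analogous to how PythonREPLTool wraps a REPL object. Every name below (Text2SpeechTool, the synthesize field) is hypothetical and not part of langchain.

from typing import Callable, Optional

from langchain.callbacks.manager import CallbackManagerForToolRun
from langchain.tools.audio_utils import save_audio
from langchain.tools.base import BaseTool


class Text2SpeechTool(BaseTool):
    """Hypothetical single tool; the provider-specific synthesis function is injected."""

    name: str = "text2speech"
    description: str = "Convert text to speech."
    synthesize: Callable[[str], bytes]  # e.g. a bound Azure or ElevenLabs function

    def _run(
        self, query: str, run_manager: Optional[CallbackManagerForToolRun] = None
    ) -> str:
        speech = self.synthesize(query)  # the provider decides how the bytes are produced
        return save_audio(speech)        # or keep in memory / play, as in this PR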
         except Exception as e:
             raise RuntimeError(f"Error while running ElevenLabsText2SpeechTool: {e}")

-    def play(self, speech_file: str) -> None:
+    def play(self, speech: bytes) -> None:
         """Play the text as speech."""
         elevenlabs = _import_elevenlabs()
-        with open(speech_file, mode="rb") as f:
-            speech = f.read()
-
         elevenlabs.play(speech)

     def stream_speech(self, query: str) -> None:

@@ -78,3 +72,15 @@ def stream_speech(self, query: str) -> None:
         elevenlabs = _import_elevenlabs()
         speech_stream = elevenlabs.generate(text=query, model=self.model, stream=True)
         elevenlabs.stream(speech_stream)
+
+    def generate_and_save(self, query: str) -> str:
+        """Save the text as speech to a temporary file."""
+        elevenlabs = _import_elevenlabs()
+        speech = elevenlabs.generate(text=query, model=self.model)
+        path = save_audio(speech)
+        return path
+
+    def load_and_play(self, path: str) -> None:
+        """Load the text as speech from a temporary file and play it."""
+        speech = load_audio(path)
+        self.play(speech)
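And a similarly rough usage sketch for the ElevenLabs tool, assuming the elevenlabs package is installed and its API key is configured in the environment; the import path is assumed, not shown in this diff:

from langchain.tools import ElevenLabsText2SpeechTool  # import path assumed

tts = ElevenLabsText2SpeechTool()

tts.run("Hello world")                        # synthesize and play directly from memory
path = tts.generate_and_save("Hello world")   # or persist to a temporary .wav
tts.load_and_play(path)
tts.stream_speech("Hello world")              # streaming playback stays unchanged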
Review comment: It seems bizarre to me to have this function use NamedTemporaryFile; why can it not be a user-provided path?
On the other hand, I generally think that langchain likely does not need save/load audio functionality; it is not core and it is a maintenance burden. Can we drop it?
Also, it is in the wrong place; it should be moved to utilities.
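One way the helper could honor a caller-supplied path while keeping the temporary-file fallback; this is only a sketch of the reviewer's suggestion, not code from the PR:

import tempfile
from pathlib import Path
from typing import Optional


def save_audio(audio: bytes, path: Optional[str] = None) -> str:
    """Write audio bytes to `path` if given, otherwise to a temporary .wav file."""
    if path is not None:
        Path(path).write_bytes(audio)
        return path
    with tempfile.NamedTemporaryFile(mode="bx", suffix=".wav", delete=False) as f:
        f.write(audio)
    return f.name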
Review comment: For example, it only works with WAV files; anyone serious about playing sounds should use a dedicated library that can handle many different formats.
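For illustration, a format-agnostic library such as pydub (assumed acceptable as an example here; it relies on ffmpeg for non-WAV codecs) can load, convert, and play many formats:

from pydub import AudioSegment
from pydub.playback import play

segment = AudioSegment.from_file("speech.mp3")  # format detected from the file, not limited to .wav
segment.export("speech.wav", format="wav")      # convert between containers when needed
play(segment)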
Review comment: I would leave those functions in the docs to show how to do it, rather than necessarily making them langchain functions; but maybe they are used by some agents, so they may need to stay...