Skip to content

Commit

Permalink
cr
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan committed Nov 28, 2023
1 parent 3e2c139 commit add9346
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 33 deletions.
41 changes: 19 additions & 22 deletions libs/langchain/langchain/document_loaders/audio.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,24 @@
from typing import Any, Iterator, List
from __future__ import annotations

from langchain_core.documents.base import Document
from pathlib import Path
from typing import Any, Union

from langchain.document_loaders import Blob
from langchain.document_loaders.base import BaseLoader
from langchain.document_loaders.blob_loaders import FileSystemBlobLoader
from langchain.document_loaders.generic import GenericLoader
from langchain.document_loaders.parsers.audio import AzureSpeechServiceParser


class AzureSpeechServiceLoader(BaseLoader):
    """Loader that parses a single file with ``AzureSpeechServiceParser``."""

    def __init__(self, file_path: str, **kwargs: Any) -> None:
        """Remember the file path and build the parser.

        Args:
            file_path: Path of the file to load.
            **kwargs: Forwarded unchanged to ``AzureSpeechServiceParser``.
        """
        super().__init__()

        self.file_path = file_path
        self.parser = AzureSpeechServiceParser(**kwargs)

    def load(self) -> List[Document]:
        """Return every document from ``lazy_load`` as a list."""
        return [*self.lazy_load()]

    def lazy_load(self) -> Iterator[Document]:
        """Return an iterator over the documents parsed from the file."""
        source_blob = Blob.from_path(self.file_path)
        parsed = self.parser.parse(source_blob)
        return iter(parsed)
class AzureSpeechServiceLoader(GenericLoader):
    """``GenericLoader`` wired to an ``AzureSpeechServiceParser``."""

    @classmethod
    def from_path(
        cls, path: Union[str, Path], **kwargs: Any
    ) -> AzureSpeechServiceLoader:
        """Create a loader from a file or directory path.

        When ``path`` is an existing file, the blob loader is rooted at the
        file's parent directory and restricted with a ``glob`` equal to the
        file name. Otherwise ``path`` itself is handed to the blob loader
        with no extra options.

        Args:
            path: File or directory location, as ``str`` or ``Path``.
            **kwargs: Forwarded unchanged to ``AzureSpeechServiceParser``.

        Returns:
            A loader built from a ``FileSystemBlobLoader`` and an
            ``AzureSpeechServiceParser``.
        """
        if not isinstance(path, Path):
            path = Path(path)

        if path.is_file():
            root, blob_options = path.parent, {"glob": path.name}
        else:
            root, blob_options = path, {}

        # Blob loader is constructed before the parser, as in the original.
        return cls(
            FileSystemBlobLoader(root, **blob_options),
            AzureSpeechServiceParser(**kwargs),
        )
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,7 @@ class BlobLoader(ABC):
"""

@abstractmethod
def yield_blobs(
self,
) -> Iterable[Blob]:
def yield_blobs(self) -> Iterable[Blob]:
"""A lazy loader for raw data represented by LangChain's Blob object.
Returns:
Expand Down
4 changes: 1 addition & 3 deletions libs/langchain/langchain/document_loaders/parsers/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,11 +346,9 @@ def lazy_parse(self, blob: Blob) -> Iterator[Document]:
except ImportError:
raise ImportError(
"azure.cognitiveservices.speech package not found, please install "
"it with `pip install azure-cognitiveservices-speech`"
"it with `pip install azure-cognitiveservices-speech`."
)

"""transcribes a conversation"""

def conversation_transcriber_recognition_canceled_cb(
evt: speechsdk.SessionEventArgs
) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def _get_csv_file_path() -> str:


def test_azure_speech_load_key_region_auto_detect_languages() -> None:
loader = AzureSpeechServiceLoader(
loader = AzureSpeechServiceLoader.from_path(
_get_csv_file_path(),
key=SPEECH_SERVICE_KEY,
region=SPEECH_SERVICE_REGION,
Expand All @@ -26,7 +26,7 @@ def test_azure_speech_load_key_region_auto_detect_languages() -> None:


def test_azure_speech_load_key_region_language() -> None:
loader = AzureSpeechServiceLoader(
loader = AzureSpeechServiceLoader.from_path(
_get_csv_file_path(),
key=SPEECH_SERVICE_KEY,
region=SPEECH_SERVICE_REGION,
Expand All @@ -37,18 +37,21 @@ def test_azure_speech_load_key_region_language() -> None:


def test_azure_speech_load_key_region() -> None:
    """Load with only a key and a region and check the transcript text.

    The loader is created through ``AzureSpeechServiceLoader.from_path``;
    the test passes when the first returned document's text contains
    "what" (case-insensitive).
    """
    loader = AzureSpeechServiceLoader.from_path(
        _get_csv_file_path(), key=SPEECH_SERVICE_KEY, region=SPEECH_SERVICE_REGION
    )
    documents = loader.load()
    assert "what" in documents[0].page_content.lower()


def test_azure_speech_load_key_endpoint() -> None:
    """Load through an explicit websocket endpoint instead of a region.

    The endpoint URL is built from ``SPEECH_SERVICE_REGION``. The test
    passes when the first returned document's text contains "what"
    (case-insensitive).
    """
    # Adjacent string literals are concatenated into one str. There must be
    # no trailing comma inside the parentheses: a comma would make this a
    # one-element tuple rather than a str.
    endpoint = (
        f"wss://{SPEECH_SERVICE_REGION}.stt.speech.microsoft.com/speech/recognition"
        "/conversation/cognitiveservices/v1"
    )
    loader = AzureSpeechServiceLoader.from_path(
        _get_csv_file_path(),
        key=SPEECH_SERVICE_KEY,
        endpoint=endpoint,
    )
    documents = loader.load()
    assert "what" in documents[0].page_content.lower()
Binary file not shown.

0 comments on commit add9346

Please sign in to comment.